Add show_progress option for trainer (#641)
- A DummyTqdm class added to utils: it replicates the interface used by trainers, but does not show the progress bar.
- Added a show_progress argument to the base trainer: when show_progress is False, DummyTqdm is used in place of tqdm.
parent 53e6b0408d
commit c87b9f49bc
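Usage sketch (illustrative, not part of the diff): the patched BaseTrainer picks tqdm.tqdm or DummyTqdm based on the show_progress flag and drives it through the same context-manager interface, so disabling the bar changes nothing else in the epoch loop. The flag value, total, and postfix below are made up for the example.

import tqdm

from tianshou.utils import DummyTqdm, tqdm_config  # both exported by this commit

show_progress = False  # e.g. the value of a --show-progress CLI flag

# Same selection the patched BaseTrainer performs before its epoch loop.
progress = tqdm.tqdm if show_progress else DummyTqdm

with progress(total=100, desc="Epoch #1", **tqdm_config) as t:
    while t.n < t.total:
        t.update(10)             # DummyTqdm still tracks t.n
        t.set_postfix(loss=0.0)  # no-op on DummyTqdm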
@@ -8,13 +8,13 @@ import cv2
 import gym
 import numpy as np
 
-from tianshou.env import ShmemVectorEnv
-
 try:
     import envpool
 except ImportError:
     envpool = None
 
+from tianshou.env import ShmemVectorEnv
+
 
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.
@@ -1,14 +1,14 @@
 import warnings
 
-import gym
-
-from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-
 try:
     import envpool
 except ImportError:
     envpool = None
 
+import gym
+
+from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
+
 
 def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
     """Wrapper function for Mujoco env.
@@ -80,7 +80,7 @@ def test_pg(args=get_args()):
         dist,
         args.gamma,
         reward_normalization=args.rew_norm,
-        action_space=env.action_space
+        action_space=env.action_space,
     )
     for m in net.modules():
         if isinstance(m, torch.nn.Linear):
@@ -116,7 +116,7 @@ def test_pg(args=get_args()):
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
         save_best_fn=save_best_fn,
-        logger=logger
+        logger=logger,
     )
     assert stop_fn(result['best_reward'])
 
@@ -58,6 +58,7 @@ def get_args():
         help='watch the play of pre-trained policy only',
     )
     parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
+    parser.add_argument("--show-progress", action="store_true")
     args = parser.parse_known_args()[0]
     return args
 
@@ -209,6 +210,7 @@ def test_bcq(args=get_args()):
         save_best_fn=save_best_fn,
         stop_fn=stop_fn,
         logger=logger,
+        show_progress=args.show_progress,
     )
     assert stop_fn(result['best_reward'])
 
@@ -9,7 +9,14 @@ import tqdm
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import BasePolicy
 from tianshou.trainer.utils import gather_info, test_episode
-from tianshou.utils import BaseLogger, LazyLogger, MovAvg, deprecation, tqdm_config
+from tianshou.utils import (
+    BaseLogger,
+    DummyTqdm,
+    LazyLogger,
+    MovAvg,
+    deprecation,
+    tqdm_config,
+)
 
 
 class BaseTrainer(ABC):
@@ -68,6 +75,8 @@ class BaseTrainer(ABC):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """
@@ -143,6 +152,7 @@ class BaseTrainer(ABC):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         save_fn: Optional[Callable[[BasePolicy], None]] = None,
     ):
@@ -190,6 +200,7 @@ class BaseTrainer(ABC):
 
         self.reward_metric = reward_metric
         self.verbose = verbose
+        self.show_progress = show_progress
         self.test_in_train = test_in_train
         self.resume_from_log = resume_from_log
 
@@ -259,8 +270,14 @@
         self.policy.train()
 
         epoch_stat: Dict[str, Any] = dict()
+
+        if self.show_progress:
+            progress = tqdm.tqdm
+        else:
+            progress = DummyTqdm
+
         # perform n step_per_epoch
-        with tqdm.tqdm(
+        with progress(
             total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
         ) as t:
             while t.n < t.total and not self.stop_fn_flag:
@@ -49,6 +49,8 @@ class OfflineTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         updating/testing. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     """
 
     __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
@@ -70,6 +72,7 @@ class OfflineTrainer(BaseTrainer):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         **kwargs: Any,
     ):
         super().__init__(
@@ -90,6 +93,7 @@ class OfflineTrainer(BaseTrainer):
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             **kwargs,
         )
 
@@ -57,6 +57,8 @@ class OffpolicyTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """
@@ -83,6 +85,7 @@ class OffpolicyTrainer(BaseTrainer):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):
@@ -106,6 +109,7 @@ class OffpolicyTrainer(BaseTrainer):
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )
@@ -60,6 +60,8 @@ class OnpolicyTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase. Default to
         True.
 
@@ -91,6 +93,7 @@ class OnpolicyTrainer(BaseTrainer):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):
@@ -115,6 +118,7 @@ class OnpolicyTrainer(BaseTrainer):
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )
@@ -1,10 +1,10 @@
 """Utils package."""
 
-from tianshou.utils.config import tqdm_config
 from tianshou.utils.logger.base import BaseLogger, LazyLogger
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
 from tianshou.utils.logger.wandb import WandbLogger
 from tianshou.utils.lr_scheduler import MultipleLRSchedulers
+from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
 from tianshou.utils.warning import deprecation
 
@@ -12,6 +12,7 @@ __all__ = [
     "MovAvg",
     "RunningMeanStd",
     "tqdm_config",
+    "DummyTqdm",
     "BaseLogger",
     "TensorboardLogger",
     "BasicLogger",
@@ -1,4 +0,0 @@
-tqdm_config = {
-    "dynamic_ncols": True,
-    "ascii": True,
-}
tianshou/utils/progress_bar.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from typing import Any
+
+tqdm_config = {
+    "dynamic_ncols": True,
+    "ascii": True,
+}
+
+
+class DummyTqdm:
+    """A dummy tqdm class that keeps stats but without progress bar.
+
+    It supports ``__enter__`` and ``__exit__``, update and a dummy
+    ``set_postfix``, which is the interface that trainers use.
+
+    .. note::
+
+        Using ``disable=True`` in tqdm config results in infinite loop, thus
+        this class is created. See the discussion at #641 for details.
+    """
+
+    def __init__(self, total: int, **kwargs: Any):
+        self.total = total
+        self.n = 0
+
+    def set_postfix(self, **kwargs: Any) -> None:
+        pass
+
+    def update(self, n: int = 1) -> None:
+        self.n += n
+
+    def __enter__(self) -> "DummyTqdm":
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        pass
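For reference, a small standalone check of the drop-in interface described in the docstring above (the loop bound and postfix value are illustrative): DummyTqdm keeps the n/total counters the trainers read, while update() and set_postfix() draw nothing.

from tianshou.utils.progress_bar import DummyTqdm

with DummyTqdm(total=5, desc="Epoch #1") as t:  # extra kwargs are ignored
    for _ in range(5):
        t.update()              # advances t.n without printing anything
        t.set_postfix(rew=0.0)  # accepted and discarded
    assert t.n == t.total       # stats survive, only the bar is gone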