From c87b9f49bc4f0010087c1d6c53a64b59f212d271 Mon Sep 17 00:00:00 2001
From: Michal Gregor
Date: Tue, 17 May 2022 17:41:59 +0200
Subject: [PATCH] Add show_progress option for trainer (#641)

- A DummyTqdm class was added to utils: it replicates the interface used by
  the trainers, but does not show the progress bar;
- Added a show_progress argument to the base trainer: when
  show_progress == False, DummyTqdm is used in place of tqdm.
---
 examples/atari/atari_wrapper.py |  4 ++--
 examples/mujoco/mujoco_env.py   |  8 ++++----
 test/discrete/test_pg.py        |  4 ++--
 test/offline/test_bcq.py        |  2 ++
 tianshou/trainer/base.py        | 21 +++++++++++++++++++--
 tianshou/trainer/offline.py     |  4 ++++
 tianshou/trainer/offpolicy.py   |  4 ++++
 tianshou/trainer/onpolicy.py    |  4 ++++
 tianshou/utils/__init__.py      |  3 ++-
 tianshou/utils/config.py        |  4 ----
 tianshou/utils/progress_bar.py  | 35 +++++++++++++++++++++++++++++++++++
 11 files changed, 78 insertions(+), 15 deletions(-)
 delete mode 100644 tianshou/utils/config.py
 create mode 100644 tianshou/utils/progress_bar.py

diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index 4aca612..96fc6f7 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -8,13 +8,13 @@ import cv2
 import gym
 import numpy as np
 
+from tianshou.env import ShmemVectorEnv
+
 try:
     import envpool
 except ImportError:
     envpool = None
 
-from tianshou.env import ShmemVectorEnv
-
 
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index e6524a4..e0036bc 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -1,14 +1,14 @@
 import warnings
 
+import gym
+
+from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
+
 try:
     import envpool
 except ImportError:
     envpool = None
 
-import gym
-
-from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-
 
 def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
     """Wrapper function for Mujoco env.
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 1f5007f..0ea8f74 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -80,7 +80,7 @@ def test_pg(args=get_args()):
         dist,
         args.gamma,
         reward_normalization=args.rew_norm,
-        action_space=env.action_space
+        action_space=env.action_space,
     )
     for m in net.modules():
         if isinstance(m, torch.nn.Linear):
@@ -116,7 +116,7 @@ def test_pg(args=get_args()):
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
         save_best_fn=save_best_fn,
-        logger=logger
+        logger=logger,
     )
     assert stop_fn(result['best_reward'])
 
diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py
index dccb2ec..a0fea4f 100644
--- a/test/offline/test_bcq.py
+++ b/test/offline/test_bcq.py
@@ -58,6 +58,7 @@ def get_args():
         help='watch the play of pre-trained policy only',
     )
     parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
+    parser.add_argument("--show-progress", action="store_true")
     args = parser.parse_known_args()[0]
     return args
 
@@ -209,6 +210,7 @@ def test_bcq(args=get_args()):
         save_best_fn=save_best_fn,
         stop_fn=stop_fn,
         logger=logger,
+        show_progress=args.show_progress,
     )
     assert stop_fn(result['best_reward'])
 
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 60525f6..2504056 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -9,7 +9,14 @@ import tqdm
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import BasePolicy
 from tianshou.trainer.utils import gather_info, test_episode
-from tianshou.utils import BaseLogger, LazyLogger, MovAvg, deprecation, tqdm_config
+from tianshou.utils import (
+    BaseLogger,
+    DummyTqdm,
+    LazyLogger,
+    MovAvg,
+    deprecation,
+    tqdm_config,
+)
 
 
 class BaseTrainer(ABC):
@@ -68,6 +75,8 @@
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """
@@ -143,6 +152,7 @@ class BaseTrainer(ABC):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         save_fn: Optional[Callable[[BasePolicy], None]] = None,
     ):
@@ -190,6 +200,7 @@ class BaseTrainer(ABC):
 
         self.reward_metric = reward_metric
         self.verbose = verbose
+        self.show_progress = show_progress
         self.test_in_train = test_in_train
         self.resume_from_log = resume_from_log
 
@@ -259,8 +270,14 @@ class BaseTrainer(ABC):
         self.policy.train()
 
         epoch_stat: Dict[str, Any] = dict()
+
+        if self.show_progress:
+            progress = tqdm.tqdm
+        else:
+            progress = DummyTqdm
+
         # perform n step_per_epoch
-        with tqdm.tqdm(
+        with progress(
             total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
         ) as t:
             while t.n < t.total and not self.stop_fn_flag:
diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py
index 82d9aa3..e60b909 100644
--- a/tianshou/trainer/offline.py
+++ b/tianshou/trainer/offline.py
@@ -49,6 +49,8 @@
     :param BaseLogger logger: A logger that logs statistics during
         updating/testing. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     """
 
     __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
@@ -70,6 +72,7 @@
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         **kwargs: Any,
     ):
         super().__init__(
@@ -90,6 +93,7 @@
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             **kwargs,
         )
 
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index e7be852..e6a42ae 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -57,6 +57,8 @@
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """
@@ -83,6 +85,7 @@
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):
@@ -106,6 +109,7 @@
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index a2234e7..0739cd9 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -60,6 +60,8 @@
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
 
@@ -91,6 +93,7 @@
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):
@@ -115,6 +118,7 @@
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index 8acd06c..78eb29d 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -1,10 +1,10 @@
 """Utils package."""
 
-from tianshou.utils.config import tqdm_config
 from tianshou.utils.logger.base import BaseLogger, LazyLogger
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
 from tianshou.utils.logger.wandb import WandbLogger
 from tianshou.utils.lr_scheduler import MultipleLRSchedulers
+from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
 from tianshou.utils.warning import deprecation
 
@@ -12,6 +12,7 @@ __all__ = [
     "MovAvg",
     "RunningMeanStd",
     "tqdm_config",
+    "DummyTqdm",
     "BaseLogger",
     "TensorboardLogger",
     "BasicLogger",
diff --git a/tianshou/utils/config.py b/tianshou/utils/config.py
deleted file mode 100644
index ca6e91f..0000000
--- a/tianshou/utils/config.py
+++ /dev/null
@@ -1,4 +0,0 @@
-tqdm_config = {
-    "dynamic_ncols": True,
-    "ascii": True,
-}
diff --git a/tianshou/utils/progress_bar.py b/tianshou/utils/progress_bar.py
new file mode 100644
index 0000000..dc3cd03
--- /dev/null
+++ b/tianshou/utils/progress_bar.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+tqdm_config = {
+    "dynamic_ncols": True,
+    "ascii": True,
+}
+
+
+class DummyTqdm:
+    """A dummy tqdm class that keeps stats, but without a progress bar.
+
+    It supports ``__enter__`` and ``__exit__``, ``update`` and a dummy
+    ``set_postfix``, which is the interface that trainers use.
+
+    .. note::
+
+        Using ``disable=True`` in tqdm config results in an infinite loop;
+        this class is used instead. See the discussion at #641 for details.
+    """
+
+    def __init__(self, total: int, **kwargs: Any):
+        self.total = total
+        self.n = 0
+
+    def set_postfix(self, **kwargs: Any) -> None:
+        pass
+
+    def update(self, n: int = 1) -> None:
+        self.n += n
+
+    def __enter__(self) -> "DummyTqdm":
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        pass
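
Usage sketch (not part of the patch): the snippet below mirrors the selection
logic that base.py now performs, using the DummyTqdm and tqdm_config names
exported from tianshou.utils by this patch. When show_progress is False,
DummyTqdm stands in for tqdm.tqdm while keeping the same
with/update/set_postfix/n/total interface. The show_progress value, the step
counts, and the postfix value are illustrative only.

    import tqdm

    from tianshou.utils import DummyTqdm, tqdm_config

    show_progress = False  # e.g. wired to a --show-progress flag, as in test_bcq.py
    step_per_epoch = 1000  # illustrative value

    # Same selection as in BaseTrainer: a real tqdm bar when show_progress is
    # True, otherwise the silent DummyTqdm replacement.
    progress = tqdm.tqdm if show_progress else DummyTqdm

    with progress(total=step_per_epoch, desc="Epoch #1", **tqdm_config) as t:
        while t.n < t.total:
            t.update(64)                 # count steps exactly like the trainer loop
            t.set_postfix(loss="0.123")  # no-op for DummyTqdm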