Add show_progress option for trainer (#641)

- Added a DummyTqdm class to tianshou.utils: it replicates the tqdm interface used by the trainers, but does not display a progress bar;
- Added a show_progress argument to the base trainer: when show_progress == False, DummyTqdm is used in place of tqdm (see the usage sketch below).
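
A minimal usage sketch, not part of the commit itself: the policy, collectors, and hyperparameter values below are hypothetical placeholders, assuming a setup like the usual tianshou examples.

from tianshou.trainer import OffpolicyTrainer

# `policy`, `train_collector` and `test_collector` are assumed to be
# constructed beforehand; the numbers are illustrative only.
result = OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=1000,
    step_per_collect=10,
    episode_per_test=10,
    batch_size=64,
    show_progress=False,  # train silently, e.g. when logging to a file
).run()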
Author: Michal Gregor, 2022-05-17 17:41:59 +02:00 (committed by GitHub)
parent 53e6b0408d
commit c87b9f49bc
11 changed files with 78 additions and 15 deletions


@@ -8,13 +8,13 @@ import cv2
 import gym
 import numpy as np
-from tianshou.env import ShmemVectorEnv

 try:
     import envpool
 except ImportError:
     envpool = None

+from tianshou.env import ShmemVectorEnv


 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.


@@ -1,14 +1,14 @@
 import warnings
-import gym
-from tianshou.env import ShmemVectorEnv, VectorEnvNormObs

 try:
     import envpool
 except ImportError:
     envpool = None

+import gym
+from tianshou.env import ShmemVectorEnv, VectorEnvNormObs


 def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
     """Wrapper function for Mujoco env.


@@ -80,7 +80,7 @@ def test_pg(args=get_args()):
         dist,
         args.gamma,
         reward_normalization=args.rew_norm,
-        action_space=env.action_space
+        action_space=env.action_space,
     )
     for m in net.modules():
         if isinstance(m, torch.nn.Linear):

@@ -116,7 +116,7 @@ def test_pg(args=get_args()):
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
         save_best_fn=save_best_fn,
-        logger=logger
+        logger=logger,
     )
     assert stop_fn(result['best_reward'])


@@ -58,6 +58,7 @@ def get_args():
         help='watch the play of pre-trained policy only',
     )
     parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
+    parser.add_argument("--show-progress", action="store_true")
     args = parser.parse_known_args()[0]
     return args

@@ -209,6 +210,7 @@ def test_bcq(args=get_args()):
         save_best_fn=save_best_fn,
         stop_fn=stop_fn,
         logger=logger,
+        show_progress=args.show_progress,
     )
     assert stop_fn(result['best_reward'])


@@ -9,7 +9,14 @@ import tqdm
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import BasePolicy
 from tianshou.trainer.utils import gather_info, test_episode
-from tianshou.utils import BaseLogger, LazyLogger, MovAvg, deprecation, tqdm_config
+from tianshou.utils import (
+    BaseLogger,
+    DummyTqdm,
+    LazyLogger,
+    MovAvg,
+    deprecation,
+    tqdm_config,
+)


 class BaseTrainer(ABC):

@@ -68,6 +75,8 @@ class BaseTrainer(ABC):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """

@@ -143,6 +152,7 @@
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         save_fn: Optional[Callable[[BasePolicy], None]] = None,
     ):

@@ -190,6 +200,7 @@
         self.reward_metric = reward_metric
         self.verbose = verbose
+        self.show_progress = show_progress
         self.test_in_train = test_in_train
         self.resume_from_log = resume_from_log

@@ -259,8 +270,14 @@
         self.policy.train()
         epoch_stat: Dict[str, Any] = dict()
+
+        if self.show_progress:
+            progress = tqdm.tqdm
+        else:
+            progress = DummyTqdm
+
         # perform n step_per_epoch
-        with tqdm.tqdm(
+        with progress(
             total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
         ) as t:
             while t.n < t.total and not self.stop_fn_flag:
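
The swap works because the trainer only relies on the small surface both classes share: construction with a total and keyword options, the context-manager protocol, update(), set_postfix(), and the n/total attributes. A standalone sketch of the same pattern (the total and loop body are illustrative, not taken from the commit):

import tqdm

from tianshou.utils import DummyTqdm, tqdm_config

show_progress = False  # mirrors BaseTrainer.show_progress
progress = tqdm.tqdm if show_progress else DummyTqdm

# Either class can drive the same loop: DummyTqdm still counts steps in
# `t.n`, so `t.n < t.total` terminates; it just renders nothing.
with progress(total=5, desc="Epoch #1", **tqdm_config) as t:
    while t.n < t.total:
        t.update(1)
        t.set_postfix(loss=0.0)  # a no-op on DummyTqdm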


@@ -49,6 +49,8 @@ class OfflineTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         updating/testing. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     """

     __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])

@@ -70,6 +72,7 @@ class OfflineTrainer(BaseTrainer):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         **kwargs: Any,
     ):
         super().__init__(

@@ -90,6 +93,7 @@ class OfflineTrainer(BaseTrainer):
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             **kwargs,
         )


@@ -57,6 +57,8 @@ class OffpolicyTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """

@@ -83,6 +85,7 @@ class OffpolicyTrainer(BaseTrainer):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):

@@ -106,6 +109,7 @@ class OffpolicyTrainer(BaseTrainer):
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )


@@ -60,6 +60,8 @@ class OnpolicyTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase. Default to
         True.

@@ -91,6 +93,7 @@ class OnpolicyTrainer(BaseTrainer):
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):

@@ -115,6 +118,7 @@ class OnpolicyTrainer(BaseTrainer):
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )


@@ -1,10 +1,10 @@
 """Utils package."""

-from tianshou.utils.config import tqdm_config
 from tianshou.utils.logger.base import BaseLogger, LazyLogger
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
 from tianshou.utils.logger.wandb import WandbLogger
 from tianshou.utils.lr_scheduler import MultipleLRSchedulers
+from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
 from tianshou.utils.warning import deprecation

@@ -12,6 +12,7 @@ __all__ = [
     "MovAvg",
     "RunningMeanStd",
     "tqdm_config",
+    "DummyTqdm",
     "BaseLogger",
     "TensorboardLogger",
     "BasicLogger",


@@ -1,4 +0,0 @@
-tqdm_config = {
-    "dynamic_ncols": True,
-    "ascii": True,
-}


@@ -0,0 +1,35 @@
+from typing import Any
+
+tqdm_config = {
+    "dynamic_ncols": True,
+    "ascii": True,
+}
+
+
+class DummyTqdm:
+    """A dummy tqdm class that keeps stats but without progress bar.
+
+    It supports ``__enter__`` and ``__exit__``, update and a dummy
+    ``set_postfix``, which is the interface that trainers use.
+
+    .. note::
+
+        Using ``disable=True`` in tqdm config results in infinite loop, thus
+        this class is created. See the discussion at #641 for details.
+    """
+
+    def __init__(self, total: int, **kwargs: Any):
+        self.total = total
+        self.n = 0
+
+    def set_postfix(self, **kwargs: Any) -> None:
+        pass
+
+    def update(self, n: int = 1) -> None:
+        self.n += n
+
+    def __enter__(self) -> "DummyTqdm":
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        pass
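
The note in the docstring refers to a tqdm quirk: with disable=True, tqdm's update() returns immediately without incrementing n, so the trainer's `while t.n < t.total` loop would spin forever. DummyTqdm keeps the counters honest while rendering nothing. A quick check of the class above, as a sketch:

from tianshou.utils import DummyTqdm

with DummyTqdm(total=3, desc="Epoch #1") as t:  # prints nothing
    while t.n < t.total:
        t.update(1)
        t.set_postfix(reward=1.0)  # accepted and discarded
    assert (t.n, t.total) == (3, 3)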