Add show_progress option for trainer (#641)

- A DummyTqdm class added to utils: it replicates the interface used by trainers, but does not show the progress bar;
- Added a show_progress argument to the base trainer: when show_progress == True, dummy_tqdm is used in place of tqdm.
This commit is contained in:
Michal Gregor 2022-05-17 17:41:59 +02:00 committed by GitHub
parent 53e6b0408d
commit c87b9f49bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 78 additions and 15 deletions

View File

@ -8,13 +8,13 @@ import cv2
import gym
import numpy as np
from tianshou.env import ShmemVectorEnv
try:
import envpool
except ImportError:
envpool = None
from tianshou.env import ShmemVectorEnv
class NoopResetEnv(gym.Wrapper):
"""Sample initial states by taking random number of no-ops on reset.

View File

@ -1,14 +1,14 @@
import warnings
import gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
try:
import envpool
except ImportError:
envpool = None
import gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
"""Wrapper function for Mujoco env.

View File

@ -80,7 +80,7 @@ def test_pg(args=get_args()):
dist,
args.gamma,
reward_normalization=args.rew_norm,
action_space=env.action_space
action_space=env.action_space,
)
for m in net.modules():
if isinstance(m, torch.nn.Linear):
@ -116,7 +116,7 @@ def test_pg(args=get_args()):
episode_per_collect=args.episode_per_collect,
stop_fn=stop_fn,
save_best_fn=save_best_fn,
logger=logger
logger=logger,
)
assert stop_fn(result['best_reward'])

View File

@ -58,6 +58,7 @@ def get_args():
help='watch the play of pre-trained policy only',
)
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
parser.add_argument("--show-progress", action="store_true")
args = parser.parse_known_args()[0]
return args
@ -209,6 +210,7 @@ def test_bcq(args=get_args()):
save_best_fn=save_best_fn,
stop_fn=stop_fn,
logger=logger,
show_progress=args.show_progress,
)
assert stop_fn(result['best_reward'])

View File

@ -9,7 +9,14 @@ import tqdm
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer.utils import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, deprecation, tqdm_config
from tianshou.utils import (
BaseLogger,
DummyTqdm,
LazyLogger,
MovAvg,
deprecation,
tqdm_config,
)
class BaseTrainer(ABC):
@ -68,6 +75,8 @@ class BaseTrainer(ABC):
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool show_progress: whether to display a progress bar when training.
Default to True.
:param bool test_in_train: whether to test in the training phase.
Default to True.
"""
@ -143,6 +152,7 @@ class BaseTrainer(ABC):
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
show_progress: bool = True,
test_in_train: bool = True,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
):
@ -190,6 +200,7 @@ class BaseTrainer(ABC):
self.reward_metric = reward_metric
self.verbose = verbose
self.show_progress = show_progress
self.test_in_train = test_in_train
self.resume_from_log = resume_from_log
@ -259,8 +270,14 @@ class BaseTrainer(ABC):
self.policy.train()
epoch_stat: Dict[str, Any] = dict()
if self.show_progress:
progress = tqdm.tqdm
else:
progress = DummyTqdm
# perform n step_per_epoch
with tqdm.tqdm(
with progress(
total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
) as t:
while t.n < t.total and not self.stop_fn_flag:

View File

@ -49,6 +49,8 @@ class OfflineTrainer(BaseTrainer):
:param BaseLogger logger: A logger that logs statistics during
updating/testing. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool show_progress: whether to display a progress bar when training.
Default to True.
"""
__doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
@ -70,6 +72,7 @@ class OfflineTrainer(BaseTrainer):
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
show_progress: bool = True,
**kwargs: Any,
):
super().__init__(
@ -90,6 +93,7 @@ class OfflineTrainer(BaseTrainer):
reward_metric=reward_metric,
logger=logger,
verbose=verbose,
show_progress=show_progress,
**kwargs,
)

View File

@ -57,6 +57,8 @@ class OffpolicyTrainer(BaseTrainer):
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool show_progress: whether to display a progress bar when training.
Default to True.
:param bool test_in_train: whether to test in the training phase.
Default to True.
"""
@ -83,6 +85,7 @@ class OffpolicyTrainer(BaseTrainer):
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
show_progress: bool = True,
test_in_train: bool = True,
**kwargs: Any,
):
@ -106,6 +109,7 @@ class OffpolicyTrainer(BaseTrainer):
reward_metric=reward_metric,
logger=logger,
verbose=verbose,
show_progress=show_progress,
test_in_train=test_in_train,
**kwargs,
)

View File

@ -60,6 +60,8 @@ class OnpolicyTrainer(BaseTrainer):
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool show_progress: whether to display a progress bar when training.
Default to True.
:param bool test_in_train: whether to test in the training phase. Default to
True.
@ -91,6 +93,7 @@ class OnpolicyTrainer(BaseTrainer):
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
show_progress: bool = True,
test_in_train: bool = True,
**kwargs: Any,
):
@ -115,6 +118,7 @@ class OnpolicyTrainer(BaseTrainer):
reward_metric=reward_metric,
logger=logger,
verbose=verbose,
show_progress=show_progress,
test_in_train=test_in_train,
**kwargs,
)

View File

@ -1,10 +1,10 @@
"""Utils package."""
from tianshou.utils.config import tqdm_config
from tianshou.utils.logger.base import BaseLogger, LazyLogger
from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
from tianshou.utils.logger.wandb import WandbLogger
from tianshou.utils.lr_scheduler import MultipleLRSchedulers
from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.warning import deprecation
@ -12,6 +12,7 @@ __all__ = [
"MovAvg",
"RunningMeanStd",
"tqdm_config",
"DummyTqdm",
"BaseLogger",
"TensorboardLogger",
"BasicLogger",

View File

@ -1,4 +0,0 @@
tqdm_config = {
"dynamic_ncols": True,
"ascii": True,
}

View File

@ -0,0 +1,35 @@
from typing import Any
tqdm_config = {
"dynamic_ncols": True,
"ascii": True,
}
class DummyTqdm:
"""A dummy tqdm class that keeps stats but without progress bar.
It supports ``__enter__`` and ``__exit__``, update and a dummy
``set_postfix``, which is the interface that trainers use.
.. note::
Using ``disable=True`` in tqdm config results in infinite loop, thus
this class is created. See the discussion at #641 for details.
"""
def __init__(self, total: int, **kwargs: Any):
self.total = total
self.n = 0
def set_postfix(self, **kwargs: Any) -> None:
pass
def update(self, n: int = 1) -> None:
self.n += n
def __enter__(self) -> "DummyTqdm":
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
pass