From 22d7bf38c8cad70be89e44e91666dcc90e193730 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 24 Sep 2021 19:22:23 +0530 Subject: [PATCH] Improve W&B logger (#441) - rename WandBLogger -> WandbLogger - add save_data and restore_data - allow more input arguments for wandb init - integrate wandb into test/modelbase/test_psrl.py and examples/atari/atari_dqn.py - documentation update --- .github/workflows/extra_sys.yml | 3 ++ .github/workflows/gputest.yml | 3 ++ .github/workflows/pytest.yml | 3 ++ README.md | 20 ++++---- docs/index.rst | 3 +- examples/atari/atari_dqn.py | 34 ++++++++++-- setup.py | 2 +- test/modelbased/test_psrl.py | 27 +++++++--- tianshou/utils/__init__.py | 4 +- tianshou/utils/logger/wandb.py | 91 +++++++++++++++++++++++++++++++-- 10 files changed, 162 insertions(+), 28 deletions(-) diff --git a/.github/workflows/extra_sys.yml b/.github/workflows/extra_sys.yml index 124ec40..df21abc 100644 --- a/.github/workflows/extra_sys.yml +++ b/.github/workflows/extra_sys.yml @@ -22,6 +22,9 @@ jobs: - name: Install dependencies run: | python -m pip install ".[dev]" --upgrade + - name: wandb login + run: | + wandb login e2366d661b89f2bee877c40bee15502d67b7abef - name: Test with pytest run: | pytest test/base test/continuous --cov=tianshou --durations=0 -v diff --git a/.github/workflows/gputest.yml b/.github/workflows/gputest.yml index b973c34..8032bd3 100644 --- a/.github/workflows/gputest.yml +++ b/.github/workflows/gputest.yml @@ -18,6 +18,9 @@ jobs: - name: Install dependencies run: | python -m pip install ".[dev]" --upgrade + - name: wandb login + run: | + wandb login e2366d661b89f2bee877c40bee15502d67b7abef - name: Test with pytest # ignore test/throughput which only profiles the code run: | diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ac52b19..102219b 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,6 +21,9 @@ jobs: - name: Install dependencies run: | python -m pip install ".[dev]" --upgrade + - name: wandb login + run: | + wandb login e2366d661b89f2bee877c40bee15502d67b7abef - name: Test with pytest # ignore test/throughput which only profiles the code run: | diff --git a/README.md b/README.md index 9e2d7a9..40fcdd3 100644 --- a/README.md +++ b/README.md @@ -47,12 +47,13 @@ Here is Tianshou's other features: - Elegant framework, using only ~4000 lines of code - State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms -- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling) -- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training) -- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation) -- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process) +- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) +- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training) +- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation) +- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process) - Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation -- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning) +- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html##multi-agent-reinforcement-learning) +- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools - Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions) In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. @@ -191,8 +192,7 @@ gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 step_per_epoch, step_per_collect = 10000, 10 -writer = SummaryWriter('log/dqn') # tensorboard is also supported! -logger = ts.utils.TensorboardLogger(writer) +logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) # TensorBoard is supported! ``` Make environments: @@ -208,7 +208,7 @@ Define the network: ```python from tianshou.utils.net.common import Net # you can define other net by following the API: -# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network +# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network env = gym.make(task) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n @@ -273,7 +273,7 @@ $ python3 test/discrete/test_pg.py --seed 0 --render 0.03 ## Contributing -Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html). +Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html). ## Citing Tianshou @@ -281,7 +281,7 @@ If you find Tianshou useful, please cite it in your publications. ```latex @article{weng2021tianshou, - title={Tianshou: a Highly Modularized Deep Reinforcement Learning Library}, + title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library}, author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Hang and Zhu, Jun}, journal={arXiv preprint arXiv:2107.14171}, year={2021} diff --git a/docs/index.rst b/docs/index.rst index 5c33245..a4dae5b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,9 +44,10 @@ Here is Tianshou's other features: * Support :ref:`customize_training` * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation * Support :doc:`/tutorials/tictactoe` +* Support both `TensorBoard `_ and `W&B `_ log tools * Comprehensive `unit tests `_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking -中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ `_ +中文文档位于 `https://tianshou.readthedocs.io/zh/master/ `_ Installation diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 67a44d0..2cd26f8 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger def get_args(): @@ -41,6 +41,13 @@ def get_args(): ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--resume-id', type=str, default=None) + parser.add_argument( + '--logger', + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) parser.add_argument( '--watch', default=False, @@ -112,9 +119,18 @@ def test_dqn(args=get_args()): test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'dqn') - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + else: + logger = WandbLogger( + save_interval=1, + project=args.task, + name='dqn', + run_id=args.resume_id, + config=args, + ) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -141,6 +157,12 @@ def test_dqn(args=get_args()): def test_fn(epoch, env_step): policy.set_eps(args.eps_test) + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + torch.save({'model': policy.state_dict()}, ckpt_path) + return ckpt_path + # watch agent's performance def watch(): print("Setup test envs ...") @@ -192,7 +214,9 @@ def test_dqn(args=get_args()): save_fn=save_fn, logger=logger, update_per_step=args.update_per_step, - test_in_train=False + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, ) pprint.pprint(result) diff --git a/setup.py b/setup.py index bf48020..80209ed 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"] ), install_requires=[ - "gym>=0.15.4", + "gym>=0.15.4,<0.20", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 3a50f36..b6b05b4 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -11,6 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import PSRLPolicy from tianshou.trainer import onpolicy_trainer +from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger def get_args(): @@ -30,6 +31,12 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--eps', type=float, default=0.01) parser.add_argument('--add-done-loop', action="store_true", default=False) + parser.add_argument( + '--logger', + type=str, + default="wandb", + choices=["wandb", "tensorboard", "none"], + ) return parser.parse_known_args()[0] @@ -72,10 +79,18 @@ def test_psrl(args=get_args()): exploration_noise=True ) test_collector = Collector(policy, test_envs) - # log - log_path = os.path.join(args.logdir, args.task, 'psrl') - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) + # Logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, project='psrl', name='wandb_test', config=args + ) + elif args.logger == "tensorboard": + log_path = os.path.join(args.logdir, args.task, 'psrl') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + else: + logger = LazyLogger() def stop_fn(mean_rewards): if env.spec.reward_threshold: @@ -96,8 +111,8 @@ def test_psrl(args=get_args()): 0, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, - # logger=logger, - test_in_train=False + logger=logger, + test_in_train=False, ) if __name__ == '__main__': diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 5af038a..25ceda1 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -3,10 +3,10 @@ from tianshou.utils.config import tqdm_config from tianshou.utils.logger.base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger -from tianshou.utils.logger.wandb import WandBLogger +from tianshou.utils.logger.wandb import WandbLogger from tianshou.utils.statistics import MovAvg, RunningMeanStd __all__ = [ "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger", - "BasicLogger", "LazyLogger", "WandBLogger" + "BasicLogger", "LazyLogger", "WandbLogger" ] diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 7a837c9..f9c047c 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -1,3 +1,7 @@ +import argparse +import os +from typing import Callable, Optional, Tuple + from tianshou.utils import BaseLogger from tianshou.utils.logger.base import LOG_DATA_TYPE @@ -7,10 +11,10 @@ except ImportError: pass -class WandBLogger(BaseLogger): - """Weights and Biases logger that sends data to Weights and Biases. +class WandbLogger(BaseLogger): + """Weights and Biases logger that sends data to https://wandb.ai/. - Creates three panels with plots: train, test, and update. + This logger creates three panels with plots: train, test, and update. Make sure to select the correct access for each panel in weights and biases: - ``train/env_step`` for train plots @@ -29,6 +33,11 @@ class WandBLogger(BaseLogger): :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. + :param str project: W&B project name. Default to "tianshou". + :param str name: W&B run name. Default to None. If None, random name is assigned. + :param str entity: W&B team/organization name. Default to None. + :param str run_id: run id of W&B run to be resumed. Default to None. + :param argparse.Namespace config: experiment configurations. Default to None. """ def __init__( @@ -36,9 +45,85 @@ class WandBLogger(BaseLogger): train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + save_interval: int = 1000, + project: str = 'tianshou', + name: Optional[str] = None, + entity: Optional[str] = None, + run_id: Optional[str] = None, + config: Optional[argparse.Namespace] = None, ) -> None: super().__init__(train_interval, test_interval, update_interval) + self.last_save_step = -1 + self.save_interval = save_interval + self.restored = False + + self.wandb_run = wandb.init( + project=project, + name=name, + id=run_id, + resume="allow", + entity=entity, + monitor_gym=True, + config=config, # type: ignore + ) if not wandb.run else wandb.run + self.wandb_run._label(repo="tianshou") # type: ignore def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: data[step_type] = step wandb.log(data) + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + ) -> None: + """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. + + :param int epoch: the epoch in trainer. + :param int env_step: the env_step in trainer. + :param int gradient_step: the gradient_step in trainer. + :param function save_checkpoint_fn: a hook defined by user, see trainer + documentation for detail. + """ + if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: + self.last_save_step = epoch + checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step) + + checkpoint_artifact = wandb.Artifact( + 'run_' + self.wandb_run.id + '_checkpoint', # type: ignore + type='model', + metadata={ + "save/epoch": epoch, + "save/env_step": env_step, + "save/gradient_step": gradient_step, + "checkpoint_path": str(checkpoint_path) + } + ) + checkpoint_artifact.add_file(str(checkpoint_path)) + self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore + + def restore_data(self) -> Tuple[int, int, int]: + checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore + 'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore + ) + assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist" + + checkpoint_artifact.download( + os.path.dirname(checkpoint_artifact.metadata['checkpoint_path']) + ) + + try: # epoch / gradient_step + epoch = checkpoint_artifact.metadata["save/epoch"] + self.last_save_step = self.last_log_test_step = epoch + gradient_step = checkpoint_artifact.metadata["save/gradient_step"] + self.last_log_update_step = gradient_step + except KeyError: + epoch, gradient_step = 0, 0 + try: # offline trainer doesn't have env_step + env_step = checkpoint_artifact.metadata["save/env_step"] + self.last_log_train_step = env_step + except KeyError: + env_step = 0 + return epoch, env_step, gradient_step