Improve W&B logger (#441)

- rename WandBLogger -> WandbLogger
- add save_data and restore_data
- allow more input arguments for wandb init (usage sketch below)
- integrate wandb into test/modelbase/test_psrl.py and examples/atari/atari_dqn.py
- documentation update
Authored by Ayush Chaurasia on 2021-09-24 19:22:23 +05:30, committed by GitHub
parent e8f8cdfa41
commit 22d7bf38c8
10 changed files with 162 additions and 28 deletions
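For orientation, here is a minimal usage sketch of the renamed logger with the constructor arguments introduced in this commit. The surrounding values are illustrative only (they are not part of the diff) and assume `wandb` is installed and `wandb login` has been run:

```python
import argparse

from tianshou.utils import WandbLogger

# Hypothetical experiment config; any argparse.Namespace works for `config`.
args = argparse.Namespace(task='CartPole-v0', lr=1e-3, seed=0)

logger = WandbLogger(
    save_interval=1,      # epochs between checkpoint-artifact uploads in save_data()
    project='tianshou',   # W&B project name (the default)
    name='dqn-cartpole',  # run name; None lets W&B generate one
    entity=None,          # W&B team/organization, if any
    run_id=None,          # pass an existing run id to resume that run
    config=args,          # stored in the run's config panel
)
# The logger is then passed to a trainer, e.g. offpolicy_trainer(..., logger=logger).
```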

File: GitHub Actions CI workflow (.github/workflows/…)

@@ -22,6 +22,9 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install ".[dev]" --upgrade
+      - name: wandb login
+        run: |
+          wandb login e2366d661b89f2bee877c40bee15502d67b7abef
       - name: Test with pytest
         run: |
           pytest test/base test/continuous --cov=tianshou --durations=0 -v

File: GitHub Actions CI workflow (.github/workflows/…)

@@ -18,6 +18,9 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install ".[dev]" --upgrade
+      - name: wandb login
+        run: |
+          wandb login e2366d661b89f2bee877c40bee15502d67b7abef
       - name: Test with pytest
         # ignore test/throughput which only profiles the code
         run: |

File: GitHub Actions CI workflow (.github/workflows/…)

@@ -21,6 +21,9 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install ".[dev]" --upgrade
+      - name: wandb login
+        run: |
+          wandb login e2366d661b89f2bee877c40bee15502d67b7abef
       - name: Test with pytest
         # ignore test/throughput which only profiles the code
         run: |

File: README.md

@@ -47,12 +47,13 @@ Here is Tianshou's other features:
 - Elegant framework, using only ~4000 lines of code
 - State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
-- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
-- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
-- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
-- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
+- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling)
+- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training)
+- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
+- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
 - Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
-- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
+- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
+- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
 - Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions)

 In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
@@ -191,8 +192,7 @@ gamma, n_step, target_freq = 0.9, 3, 320
 buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
 step_per_epoch, step_per_collect = 10000, 10
-writer = SummaryWriter('log/dqn')  # tensorboard is also supported!
-logger = ts.utils.TensorboardLogger(writer)
+logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))  # TensorBoard is supported!
 ```

 Make environments:
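(Aside, not part of the diff: with this commit, the W&B counterpart of the quickstart's logger line would look roughly like the sketch below; the project and run names are arbitrary and `wandb login` is assumed.)

```python
import tianshou as ts

logger = ts.utils.WandbLogger(project='tianshou-dqn', name='quickstart')
```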
@@ -208,7 +208,7 @@ Define the network:
 ```python
 from tianshou.utils.net.common import Net
 # you can define other net by following the API:
-# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network
+# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
 env = gym.make(task)
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
@@ -273,7 +273,7 @@ $ python3 test/discrete/test_pg.py --seed 0 --render 0.03
 ## Contributing

-Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html).
+Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).

 ## Citing Tianshou
@@ -281,7 +281,7 @@ If you find Tianshou useful, please cite it in your publications.
 ```latex
 @article{weng2021tianshou,
-  title={Tianshou: a Highly Modularized Deep Reinforcement Learning Library},
+  title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
   author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Hang and Zhu, Jun},
   journal={arXiv preprint arXiv:2107.14171},
   year={2021}

File: docs/index.rst

@@ -44,9 +44,10 @@ Here is Tianshou's other features:
 * Support :ref:`customize_training`
 * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
 * Support :doc:`/tutorials/tictactoe`
+* Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools
 * Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking

-中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
+中文文档位于 `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_

 Installation

File: examples/atari/atari_dqn.py

@@ -12,7 +12,7 @@ from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
-from tianshou.utils import TensorboardLogger
+from tianshou.utils import TensorboardLogger, WandbLogger


 def get_args():
@@ -41,6 +41,13 @@ def get_args():
     )
     parser.add_argument('--frames-stack', type=int, default=4)
     parser.add_argument('--resume-path', type=str, default=None)
+    parser.add_argument('--resume-id', type=str, default=None)
+    parser.add_argument(
+        '--logger',
+        type=str,
+        default="tensorboard",
+        choices=["tensorboard", "wandb"],
+    )
     parser.add_argument(
         '--watch',
         default=False,
@@ -112,9 +119,18 @@ def test_dqn(args=get_args()):
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
-    writer = SummaryWriter(log_path)
-    writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer)
+    if args.logger == "tensorboard":
+        writer = SummaryWriter(log_path)
+        writer.add_text("args", str(args))
+        logger = TensorboardLogger(writer)
+    else:
+        logger = WandbLogger(
+            save_interval=1,
+            project=args.task,
+            name='dqn',
+            run_id=args.resume_id,
+            config=args,
+        )

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -141,6 +157,12 @@ def test_dqn(args=get_args()):
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

+    def save_checkpoint_fn(epoch, env_step, gradient_step):
+        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+        torch.save({'model': policy.state_dict()}, ckpt_path)
+        return ckpt_path
+
     # watch agent's performance
     def watch():
         print("Setup test envs ...")
@@ -192,7 +214,9 @@ def test_dqn(args=get_args()):
         save_fn=save_fn,
         logger=logger,
         update_per_step=args.update_per_step,
-        test_in_train=False
+        test_in_train=False,
+        resume_from_log=args.resume_id is not None,
+        save_checkpoint_fn=save_checkpoint_fn,
     )

     pprint.pprint(result)
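The checkpoint hook added above is what makes `--resume-id` useful: the trainer calls it as `save_checkpoint_fn(epoch, env_step, gradient_step)`, and `WandbLogger.save_data()` uploads whatever path it returns as a W&B artifact. Below is a self-contained sketch of that contract; the linear layer and temporary directory are stand-ins for the script's policy and log path, not part of this diff:

```python
import os
import tempfile

import torch
import torch.nn as nn

log_path = tempfile.mkdtemp()   # stand-in for os.path.join(args.logdir, args.task, 'dqn')
policy = nn.Linear(4, 2)        # stand-in for the DQNPolicy built in the script


def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
    # The trainer calls this hook; the returned path is what WandbLogger.save_data()
    # attaches to the 'run_<id>_checkpoint' artifact.
    ckpt_path = os.path.join(log_path, 'checkpoint.pth')
    torch.save({'model': policy.state_dict()}, ckpt_path)
    return ckpt_path


print(save_checkpoint_fn(0, 0, 0))
```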

File: setup.py

@@ -47,7 +47,7 @@ setup(
         exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"]
     ),
     install_requires=[
-        "gym>=0.15.4",
+        "gym>=0.15.4,<0.20",
         "tqdm",
         "numpy>1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard>=2.5.0",

File: test/modelbase/test_psrl.py

@@ -11,6 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 from tianshou.policy import PSRLPolicy
 from tianshou.trainer import onpolicy_trainer
+from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger


 def get_args():
@@ -30,6 +31,12 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--eps', type=float, default=0.01)
     parser.add_argument('--add-done-loop', action="store_true", default=False)
+    parser.add_argument(
+        '--logger',
+        type=str,
+        default="wandb",
+        choices=["wandb", "tensorboard", "none"],
+    )
     return parser.parse_known_args()[0]
@@ -72,10 +79,18 @@ def test_psrl(args=get_args()):
         exploration_noise=True
     )
     test_collector = Collector(policy, test_envs)
-    # log
-    log_path = os.path.join(args.logdir, args.task, 'psrl')
-    writer = SummaryWriter(log_path)
-    writer.add_text("args", str(args))
+    # Logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1, project='psrl', name='wandb_test', config=args
+        )
+    elif args.logger == "tensorboard":
+        log_path = os.path.join(args.logdir, args.task, 'psrl')
+        writer = SummaryWriter(log_path)
+        writer.add_text("args", str(args))
+        logger = TensorboardLogger(writer)
+    else:
+        logger = LazyLogger()

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
@@ -96,8 +111,8 @@ def test_psrl(args=get_args()):
         0,
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
-        # logger=logger,
-        test_in_train=False
+        logger=logger,
+        test_in_train=False,
     )

 if __name__ == '__main__':
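The same three-way selection can be lifted out of the test script. A sketch is below; the helper name and defaults are mine, not part of the diff. `LazyLogger` is a no-op implementation of the logger interface, so the `onpolicy_trainer(..., logger=logger)` call works unchanged for all three choices:

```python
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger


def make_logger(kind: str = "none"):
    # "wandb" requires `wandb login`; "tensorboard" writes event files under log/psrl.
    if kind == "wandb":
        return WandbLogger(save_interval=1, project='psrl', name='wandb_test')
    if kind == "tensorboard":
        from torch.utils.tensorboard import SummaryWriter
        return TensorboardLogger(SummaryWriter('log/psrl'))
    return LazyLogger()  # no-op logger: the trainer runs without recording anything
```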

File: tianshou/utils/__init__.py

@@ -3,10 +3,10 @@
 from tianshou.utils.config import tqdm_config
 from tianshou.utils.logger.base import BaseLogger, LazyLogger
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
-from tianshou.utils.logger.wandb import WandBLogger
+from tianshou.utils.logger.wandb import WandbLogger
 from tianshou.utils.statistics import MovAvg, RunningMeanStd

 __all__ = [
     "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger",
-    "BasicLogger", "LazyLogger", "WandBLogger"
+    "BasicLogger", "LazyLogger", "WandbLogger"
 ]

File: tianshou/utils/logger/wandb.py

@@ -1,3 +1,7 @@
+import argparse
+import os
+from typing import Callable, Optional, Tuple
+
 from tianshou.utils import BaseLogger
 from tianshou.utils.logger.base import LOG_DATA_TYPE
@@ -7,10 +11,10 @@ except ImportError:
     pass


-class WandBLogger(BaseLogger):
-    """Weights and Biases logger that sends data to Weights and Biases.
+class WandbLogger(BaseLogger):
+    """Weights and Biases logger that sends data to https://wandb.ai/.

-    Creates three panels with plots: train, test, and update.
+    This logger creates three panels with plots: train, test, and update.
     Make sure to select the correct access for each panel in weights and biases:

     - ``train/env_step`` for train plots
@@ -29,6 +33,11 @@ class WandBLogger(BaseLogger):
     :param int test_interval: the log interval in log_test_data(). Default to 1.
     :param int update_interval: the log interval in log_update_data().
         Default to 1000.
+    :param str project: W&B project name. Default to "tianshou".
+    :param str name: W&B run name. Default to None. If None, random name is assigned.
+    :param str entity: W&B team/organization name. Default to None.
+    :param str run_id: run id of W&B run to be resumed. Default to None.
+    :param argparse.Namespace config: experiment configurations. Default to None.
     """
@@ -36,9 +45,85 @@ class WandBLogger(BaseLogger):
     def __init__(
         self,
         train_interval: int = 1000,
         test_interval: int = 1,
         update_interval: int = 1000,
+        save_interval: int = 1000,
+        project: str = 'tianshou',
+        name: Optional[str] = None,
+        entity: Optional[str] = None,
+        run_id: Optional[str] = None,
+        config: Optional[argparse.Namespace] = None,
     ) -> None:
         super().__init__(train_interval, test_interval, update_interval)
+        self.last_save_step = -1
+        self.save_interval = save_interval
+        self.restored = False
+        self.wandb_run = wandb.init(
+            project=project,
+            name=name,
+            id=run_id,
+            resume="allow",
+            entity=entity,
+            monitor_gym=True,
+            config=config,  # type: ignore
+        ) if not wandb.run else wandb.run
+        self.wandb_run._label(repo="tianshou")  # type: ignore

     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
         data[step_type] = step
         wandb.log(data)
+
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+    ) -> None:
+        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
+
+        :param int epoch: the epoch in trainer.
+        :param int env_step: the env_step in trainer.
+        :param int gradient_step: the gradient_step in trainer.
+        :param function save_checkpoint_fn: a hook defined by user, see trainer
+            documentation for detail.
+        """
+        if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
+            self.last_save_step = epoch
+            checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)
+            checkpoint_artifact = wandb.Artifact(
+                'run_' + self.wandb_run.id + '_checkpoint',  # type: ignore
+                type='model',
+                metadata={
+                    "save/epoch": epoch,
+                    "save/env_step": env_step,
+                    "save/gradient_step": gradient_step,
+                    "checkpoint_path": str(checkpoint_path)
+                }
+            )
+            checkpoint_artifact.add_file(str(checkpoint_path))
+            self.wandb_run.log_artifact(checkpoint_artifact)  # type: ignore
+
+    def restore_data(self) -> Tuple[int, int, int]:
+        checkpoint_artifact = self.wandb_run.use_artifact(  # type: ignore
+            'run_' + self.wandb_run.id + '_checkpoint:latest'  # type: ignore
+        )
+        assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
+
+        checkpoint_artifact.download(
+            os.path.dirname(checkpoint_artifact.metadata['checkpoint_path'])
+        )
+
+        try:  # epoch / gradient_step
+            epoch = checkpoint_artifact.metadata["save/epoch"]
+            self.last_save_step = self.last_log_test_step = epoch
+            gradient_step = checkpoint_artifact.metadata["save/gradient_step"]
+            self.last_log_update_step = gradient_step
+        except KeyError:
+            epoch, gradient_step = 0, 0
+        try:  # offline trainer doesn't have env_step
+            env_step = checkpoint_artifact.metadata["save/env_step"]
+            self.last_log_train_step = env_step
+        except KeyError:
+            env_step = 0
+        return epoch, env_step, gradient_step
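Taken together, `save_data` and `restore_data` give a checkpoint round trip through a single W&B artifact named `run_<run id>_checkpoint`: `save_data` records epoch, env_step and gradient_step plus the checkpoint path in the artifact metadata and attaches the file, while `restore_data` downloads the `:latest` version next to the recorded path and replays those counters into the logger so the trainer can continue where it left off. A sketch of how a resumed run would drive this, assuming the run id is a placeholder and a prior run has already logged the artifact:

```python
from tianshou.utils import WandbLogger

# Resume the hypothetical W&B run 'abc123'; wandb.init(resume="allow") reattaches to it.
logger = WandbLogger(save_interval=1, project='tianshou', run_id='abc123')

# This is what the trainer does when resume_from_log=True: the returned counters
# become the starting epoch / env_step / gradient_step of the continued run.
epoch, env_step, gradient_step = logger.restore_data()
print(epoch, env_step, gradient_step)
```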