Improve W&B logger (#441)

- rename WandBLogger -> WandbLogger
- add save_data and restore_data
- allow more input arguments for wandb init
- integrate wandb into test/modelbase/test_psrl.py and examples/atari/atari_dqn.py
- documentation update
Ayush Chaurasia 2021-09-24 19:22:23 +05:30 committed by GitHub
parent e8f8cdfa41
commit 22d7bf38c8
10 changed files with 162 additions and 28 deletions
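A minimal usage sketch of the renamed logger and its new checkpoint hook, based on the diffs below. The project/run names and the config namespace here are placeholders, not part of this commit:

```python
import argparse

import torch

from tianshou.utils import WandbLogger

# placeholder experiment config; any argparse.Namespace (e.g. parsed CLI args) works
config = argparse.Namespace(task='CartPole-v0', lr=1e-3)

logger = WandbLogger(
    save_interval=1,       # upload a checkpoint artifact every epoch
    project='my-project',  # placeholder W&B project name
    name='dqn-run',        # placeholder W&B run name
    run_id=None,           # pass an existing run id here to resume that run
    config=config,         # stored in the W&B run config
)

def save_checkpoint_fn(epoch, env_step, gradient_step):
    # the trainer calls this hook; WandbLogger.save_data() then uploads the
    # returned file to W&B as a versioned checkpoint artifact
    ckpt_path = 'checkpoint.pth'
    torch.save(
        {'epoch': epoch, 'env_step': env_step, 'gradient_step': gradient_step},
        ckpt_path,
    )
    return ckpt_path
```

Both objects are then handed to the trainer via its `logger=` and `save_checkpoint_fn=` arguments, as in the updated `examples/atari/atari_dqn.py` below.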

.github/workflows (CI workflow)

@@ -22,6 +22,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
run: |
pytest test/base test/continuous --cov=tianshou --durations=0 -v

.github/workflows (CI workflow)

@@ -18,6 +18,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |

.github/workflows (CI workflow)

@@ -21,6 +21,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |

README.md

@@ -47,12 +47,13 @@ Here is Tianshou's other features:
- Elegant framework, using only ~4000 lines of code
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
- Support n-step return estimation and prioritized experience replay for all Q-learning based algorithms; GAE, n-step and PER are very fast thanks to numba's JIT compilation and vectorized numpy operations
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
- Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions)
In Chinese, Tianshou means divinely ordained, derived from the notion of a gift one is born with. Tianshou is a reinforcement learning platform, and its RL algorithms do not learn from humans. Taking the name "Tianshou" thus means there is no teacher to study under; the agent learns by itself through constant interaction with the environment.
@@ -191,8 +192,7 @@ gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
logger = ts.utils.TensorboardLogger(writer)
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) # TensorBoard is supported!
```
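For the W&B logging mentioned in the feature list, a rough counterpart of the TensorBoard line above (assuming `import tianshou as ts` from the quick-start; the project name is a placeholder) would be:

```python
logger = ts.utils.WandbLogger(project='tianshou-demo', name='dqn')  # W&B instead of TensorBoard
```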
Make environments:
@@ -208,7 +208,7 @@ Define the network:
```python
from tianshou.utils.net.common import Net
# you can define other networks by following the API:
# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
@@ -273,7 +273,7 @@ $ python3 test/discrete/test_pg.py --seed 0 --render 0.03
## Contributing
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html).
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).
## Citing Tianshou
@@ -281,7 +281,7 @@ If you find Tianshou useful, please cite it in your publications.
```latex
@article{weng2021tianshou,
title={Tianshou: a Highly Modularized Deep Reinforcement Learning Library},
title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Hang and Zhu, Jun},
journal={arXiv preprint arXiv:2107.14171},
year={2021}

docs/index.rst

@@ -44,9 +44,10 @@ Here is Tianshou's other features:
* Support :ref:`customize_training`
* Support n-step return estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, n-step and PER are very fast thanks to numba's JIT compilation and vectorized numpy operations
* Support :doc:`/tutorials/tictactoe`
* Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
The Chinese documentation is available at `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
The Chinese documentation is available at `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_
Installation

examples/atari/atari_dqn.py

@@ -12,7 +12,7 @@ from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger
def get_args():
@@ -41,6 +41,13 @@ def get_args():
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None)
parser.add_argument(
'--logger',
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument(
'--watch',
default=False,
@@ -112,9 +119,18 @@ def test_dqn(args=get_args()):
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
logger = WandbLogger(
save_interval=1,
project=args.task,
name='dqn',
run_id=args.resume_id,
config=args,
)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -141,6 +157,12 @@ def test_dqn(args=get_args()):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
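# the trainer passes this hook to logger.save_data(); with WandbLogger the returned
# path is additionally uploaded to W&B as a versioned checkpoint artifact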
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
torch.save({'model': policy.state_dict()}, ckpt_path)
return ckpt_path
# watch agent's performance
def watch():
print("Setup test envs ...")
@@ -192,7 +214,9 @@ def test_dqn(args=get_args()):
save_fn=save_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
test_in_train=False,
resume_from_log=args.resume_id is not None,
save_checkpoint_fn=save_checkpoint_fn,
)
pprint.pprint(result)

setup.py

@@ -47,7 +47,7 @@ setup(
exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"]
),
install_requires=[
"gym>=0.15.4",
"gym>=0.15.4,<0.20",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard>=2.5.0",

test/modelbase/test_psrl.py

@@ -11,6 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import PSRLPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger
def get_args():
@@ -30,6 +31,12 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--eps', type=float, default=0.01)
parser.add_argument('--add-done-loop', action="store_true", default=False)
parser.add_argument(
'--logger',
type=str,
default="wandb",
choices=["wandb", "tensorboard", "none"],
)
return parser.parse_known_args()[0]
@@ -72,10 +79,18 @@ def test_psrl(args=get_args()):
exploration_noise=True
)
test_collector = Collector(policy, test_envs)
# log
# Logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1, project='psrl', name='wandb_test', config=args
)
elif args.logger == "tensorboard":
log_path = os.path.join(args.logdir, args.task, 'psrl')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
logger = LazyLogger()
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
@@ -96,8 +111,8 @@
0,
episode_per_collect=args.episode_per_collect,
stop_fn=stop_fn,
# logger=logger,
test_in_train=False
logger=logger,
test_in_train=False,
)
if __name__ == '__main__':

tianshou/utils/__init__.py

@@ -3,10 +3,10 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.logger.base import BaseLogger, LazyLogger
from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
from tianshou.utils.logger.wandb import WandBLogger
from tianshou.utils.logger.wandb import WandbLogger
from tianshou.utils.statistics import MovAvg, RunningMeanStd
__all__ = [
"MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger",
"BasicLogger", "LazyLogger", "WandBLogger"
"BasicLogger", "LazyLogger", "WandbLogger"
]

tianshou/utils/logger/wandb.py

@@ -1,3 +1,7 @@
import argparse
import os
from typing import Callable, Optional, Tuple
from tianshou.utils import BaseLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE
@@ -7,10 +11,10 @@ except ImportError:
pass
class WandBLogger(BaseLogger):
"""Weights and Biases logger that sends data to Weights and Biases.
class WandbLogger(BaseLogger):
"""Weights and Biases logger that sends data to https://wandb.ai/.
Creates three panels with plots: train, test, and update.
This logger creates three panels with plots: train, test, and update.
Make sure to select the correct x-axis key for each panel in Weights & Biases:
- ``train/env_step`` for train plots
@@ -29,6 +33,11 @@ class WandBLogger(BaseLogger):
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data().
Default to 1000.
:param str project: W&B project name. Default to "tianshou".
:param str name: W&B run name. Default to None. If None, a random name is assigned.
:param str entity: W&B team/organization name. Default to None.
:param str run_id: run id of W&B run to be resumed. Default to None.
:param argparse.Namespace config: experiment configurations. Default to None.
"""
def __init__(
@@ -36,9 +45,85 @@ class WandBLogger(BaseLogger):
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1000,
project: str = 'tianshou',
name: Optional[str] = None,
entity: Optional[str] = None,
run_id: Optional[str] = None,
config: Optional[argparse.Namespace] = None,
) -> None:
super().__init__(train_interval, test_interval, update_interval)
self.last_save_step = -1
self.save_interval = save_interval
self.restored = False
self.wandb_run = wandb.init(
project=project,
name=name,
id=run_id,
resume="allow",
entity=entity,
monitor_gym=True,
config=config, # type: ignore
) if not wandb.run else wandb.run
self.wandb_run._label(repo="tianshou") # type: ignore
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
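# logging the step value under its step_type key (e.g. train/env_step) lets that key
# be selected as the x-axis of the corresponding panel in the W&B UI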
data[step_type] = step
wandb.log(data)
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
) -> None:
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
:param int epoch: the epoch in trainer.
:param int env_step: the env_step in trainer.
:param int gradient_step: the gradient_step in trainer.
:param function save_checkpoint_fn: a hook defined by the user; see the trainer
documentation for details.
"""
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
self.last_save_step = epoch
checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)
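# wrap the checkpoint file in a W&B artifact so it is versioned with the run and can
# be fetched again later by restore_data()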
checkpoint_artifact = wandb.Artifact(
'run_' + self.wandb_run.id + '_checkpoint', # type: ignore
type='model',
metadata={
"save/epoch": epoch,
"save/env_step": env_step,
"save/gradient_step": gradient_step,
"checkpoint_path": str(checkpoint_path)
}
)
checkpoint_artifact.add_file(str(checkpoint_path))
self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore
def restore_data(self) -> Tuple[int, int, int]:
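# download the run's latest checkpoint artifact and read the trainer counters
# (epoch, env_step, gradient_step) back from its metadata so training can resume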
checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore
'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore
)
assert checkpoint_artifact is not None, "W&B checkpoint artifact doesn't exist"
checkpoint_artifact.download(
os.path.dirname(checkpoint_artifact.metadata['checkpoint_path'])
)
try: # epoch / gradient_step
epoch = checkpoint_artifact.metadata["save/epoch"]
self.last_save_step = self.last_log_test_step = epoch
gradient_step = checkpoint_artifact.metadata["save/gradient_step"]
self.last_log_update_step = gradient_step
except KeyError:
epoch, gradient_step = 0, 0
try: # offline trainer doesn't have env_step
env_step = checkpoint_artifact.metadata["save/env_step"]
self.last_log_train_step = env_step
except KeyError:
env_step = 0
return epoch, env_step, gradient_step
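For illustration, resuming from the checkpoint artifact saved above would look roughly like the sketch below; the run id and project name are placeholders, and in the updated `atari_dqn.py` example the trainer is expected to call `restore_data()` itself when `resume_from_log` is set:

```python
from tianshou.utils import WandbLogger

# 'abc123' stands in for the id of an earlier W&B run in the same project
logger = WandbLogger(
    save_interval=1,
    project='my-project',
    name='dqn-run',
    run_id='abc123',
)
# downloads run_<id>_checkpoint:latest and returns the counters stored in its metadata
epoch, env_step, gradient_step = logger.restore_data()
```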