Improve W&B logger (#441)
- rename WandBLogger -> WandbLogger - add save_data and restore_data - allow more input arguments for wandb init - integrate wandb into test/modelbase/test_psrl.py and examples/atari/atari_dqn.py - documentation update
This commit is contained in:
parent
e8f8cdfa41
commit
22d7bf38c8
3
.github/workflows/extra_sys.yml
vendored
3
.github/workflows/extra_sys.yml
vendored
@ -22,6 +22,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install ".[dev]" --upgrade
|
||||
- name: wandb login
|
||||
run: |
|
||||
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest test/base test/continuous --cov=tianshou --durations=0 -v
|
||||
|
3
.github/workflows/gputest.yml
vendored
3
.github/workflows/gputest.yml
vendored
@ -18,6 +18,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install ".[dev]" --upgrade
|
||||
- name: wandb login
|
||||
run: |
|
||||
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
|
||||
- name: Test with pytest
|
||||
# ignore test/throughput which only profiles the code
|
||||
run: |
|
||||
|
3
.github/workflows/pytest.yml
vendored
3
.github/workflows/pytest.yml
vendored
@ -21,6 +21,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install ".[dev]" --upgrade
|
||||
- name: wandb login
|
||||
run: |
|
||||
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
|
||||
- name: Test with pytest
|
||||
# ignore test/throughput which only profiles the code
|
||||
run: |
|
||||
|
20
README.md
20
README.md
@ -47,12 +47,13 @@ Here is Tianshou's other features:
|
||||
|
||||
- Elegant framework, using only ~4000 lines of code
|
||||
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
|
||||
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
|
||||
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
|
||||
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
|
||||
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
|
||||
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling)
|
||||
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training)
|
||||
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
|
||||
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
|
||||
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
|
||||
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
|
||||
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
|
||||
- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
|
||||
- Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions)
|
||||
|
||||
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
|
||||
@ -191,8 +192,7 @@ gamma, n_step, target_freq = 0.9, 3, 320
|
||||
buffer_size = 20000
|
||||
eps_train, eps_test = 0.1, 0.05
|
||||
step_per_epoch, step_per_collect = 10000, 10
|
||||
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
|
||||
logger = ts.utils.TensorboardLogger(writer)
|
||||
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) # TensorBoard is supported!
|
||||
```
|
||||
|
||||
Make environments:
|
||||
@ -208,7 +208,7 @@ Define the network:
|
||||
```python
|
||||
from tianshou.utils.net.common import Net
|
||||
# you can define other net by following the API:
|
||||
# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network
|
||||
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
|
||||
env = gym.make(task)
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
action_shape = env.action_space.shape or env.action_space.n
|
||||
@ -273,7 +273,7 @@ $ python3 test/discrete/test_pg.py --seed 0 --render 0.03
|
||||
|
||||
## Contributing
|
||||
|
||||
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html).
|
||||
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).
|
||||
|
||||
## Citing Tianshou
|
||||
|
||||
@ -281,7 +281,7 @@ If you find Tianshou useful, please cite it in your publications.
|
||||
|
||||
```latex
|
||||
@article{weng2021tianshou,
|
||||
title={Tianshou: a Highly Modularized Deep Reinforcement Learning Library},
|
||||
title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
|
||||
author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Hang and Zhu, Jun},
|
||||
journal={arXiv preprint arXiv:2107.14171},
|
||||
year={2021}
|
||||
|
@ -44,9 +44,10 @@ Here is Tianshou's other features:
|
||||
* Support :ref:`customize_training`
|
||||
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
|
||||
* Support :doc:`/tutorials/tictactoe`
|
||||
* Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools
|
||||
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
|
||||
|
||||
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
|
||||
中文文档位于 `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_
|
||||
|
||||
|
||||
Installation
|
||||
|
@ -12,7 +12,7 @@ from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
|
||||
|
||||
def get_args():
|
||||
@ -41,6 +41,13 @@ def get_args():
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument('--resume-id', type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--logger',
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
@ -112,9 +119,18 @@ def test_dqn(args=get_args()):
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
if args.logger == "tensorboard":
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
else:
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
project=args.task,
|
||||
name='dqn',
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
@ -141,6 +157,12 @@ def test_dqn(args=get_args()):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
torch.save({'model': policy.state_dict()}, ckpt_path)
|
||||
return ckpt_path
|
||||
|
||||
# watch agent's performance
|
||||
def watch():
|
||||
print("Setup test envs ...")
|
||||
@ -192,7 +214,9 @@ def test_dqn(args=get_args()):
|
||||
save_fn=save_fn,
|
||||
logger=logger,
|
||||
update_per_step=args.update_per_step,
|
||||
test_in_train=False
|
||||
test_in_train=False,
|
||||
resume_from_log=args.resume_id is not None,
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
)
|
||||
|
||||
pprint.pprint(result)
|
||||
|
2
setup.py
2
setup.py
@ -47,7 +47,7 @@ setup(
|
||||
exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"]
|
||||
),
|
||||
install_requires=[
|
||||
"gym>=0.15.4",
|
||||
"gym>=0.15.4,<0.20",
|
||||
"tqdm",
|
||||
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
|
||||
"tensorboard>=2.5.0",
|
||||
|
@ -11,6 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||
from tianshou.policy import PSRLPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger
|
||||
|
||||
|
||||
def get_args():
|
||||
@ -30,6 +31,12 @@ def get_args():
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--eps', type=float, default=0.01)
|
||||
parser.add_argument('--add-done-loop', action="store_true", default=False)
|
||||
parser.add_argument(
|
||||
'--logger',
|
||||
type=str,
|
||||
default="wandb",
|
||||
choices=["wandb", "tensorboard", "none"],
|
||||
)
|
||||
return parser.parse_known_args()[0]
|
||||
|
||||
|
||||
@ -72,10 +79,18 @@ def test_psrl(args=get_args()):
|
||||
exploration_noise=True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'psrl')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
# Logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1, project='psrl', name='wandb_test', config=args
|
||||
)
|
||||
elif args.logger == "tensorboard":
|
||||
log_path = os.path.join(args.logdir, args.task, 'psrl')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
else:
|
||||
logger = LazyLogger()
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
@ -96,8 +111,8 @@ def test_psrl(args=get_args()):
|
||||
0,
|
||||
episode_per_collect=args.episode_per_collect,
|
||||
stop_fn=stop_fn,
|
||||
# logger=logger,
|
||||
test_in_train=False
|
||||
logger=logger,
|
||||
test_in_train=False,
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -3,10 +3,10 @@
|
||||
from tianshou.utils.config import tqdm_config
|
||||
from tianshou.utils.logger.base import BaseLogger, LazyLogger
|
||||
from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
|
||||
from tianshou.utils.logger.wandb import WandBLogger
|
||||
from tianshou.utils.logger.wandb import WandbLogger
|
||||
from tianshou.utils.statistics import MovAvg, RunningMeanStd
|
||||
|
||||
__all__ = [
|
||||
"MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger",
|
||||
"BasicLogger", "LazyLogger", "WandBLogger"
|
||||
"BasicLogger", "LazyLogger", "WandbLogger"
|
||||
]
|
||||
|
@ -1,3 +1,7 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from tianshou.utils import BaseLogger
|
||||
from tianshou.utils.logger.base import LOG_DATA_TYPE
|
||||
|
||||
@ -7,10 +11,10 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class WandBLogger(BaseLogger):
|
||||
"""Weights and Biases logger that sends data to Weights and Biases.
|
||||
class WandbLogger(BaseLogger):
|
||||
"""Weights and Biases logger that sends data to https://wandb.ai/.
|
||||
|
||||
Creates three panels with plots: train, test, and update.
|
||||
This logger creates three panels with plots: train, test, and update.
|
||||
Make sure to select the correct access for each panel in weights and biases:
|
||||
|
||||
- ``train/env_step`` for train plots
|
||||
@ -29,6 +33,11 @@ class WandBLogger(BaseLogger):
|
||||
:param int test_interval: the log interval in log_test_data(). Default to 1.
|
||||
:param int update_interval: the log interval in log_update_data().
|
||||
Default to 1000.
|
||||
:param str project: W&B project name. Default to "tianshou".
|
||||
:param str name: W&B run name. Default to None. If None, random name is assigned.
|
||||
:param str entity: W&B team/organization name. Default to None.
|
||||
:param str run_id: run id of W&B run to be resumed. Default to None.
|
||||
:param argparse.Namespace config: experiment configurations. Default to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -36,9 +45,85 @@ class WandBLogger(BaseLogger):
|
||||
train_interval: int = 1000,
|
||||
test_interval: int = 1,
|
||||
update_interval: int = 1000,
|
||||
save_interval: int = 1000,
|
||||
project: str = 'tianshou',
|
||||
name: Optional[str] = None,
|
||||
entity: Optional[str] = None,
|
||||
run_id: Optional[str] = None,
|
||||
config: Optional[argparse.Namespace] = None,
|
||||
) -> None:
|
||||
super().__init__(train_interval, test_interval, update_interval)
|
||||
self.last_save_step = -1
|
||||
self.save_interval = save_interval
|
||||
self.restored = False
|
||||
|
||||
self.wandb_run = wandb.init(
|
||||
project=project,
|
||||
name=name,
|
||||
id=run_id,
|
||||
resume="allow",
|
||||
entity=entity,
|
||||
monitor_gym=True,
|
||||
config=config, # type: ignore
|
||||
) if not wandb.run else wandb.run
|
||||
self.wandb_run._label(repo="tianshou") # type: ignore
|
||||
|
||||
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
||||
data[step_type] = step
|
||||
wandb.log(data)
|
||||
|
||||
def save_data(
|
||||
self,
|
||||
epoch: int,
|
||||
env_step: int,
|
||||
gradient_step: int,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
) -> None:
|
||||
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
|
||||
|
||||
:param int epoch: the epoch in trainer.
|
||||
:param int env_step: the env_step in trainer.
|
||||
:param int gradient_step: the gradient_step in trainer.
|
||||
:param function save_checkpoint_fn: a hook defined by user, see trainer
|
||||
documentation for detail.
|
||||
"""
|
||||
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
|
||||
self.last_save_step = epoch
|
||||
checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)
|
||||
|
||||
checkpoint_artifact = wandb.Artifact(
|
||||
'run_' + self.wandb_run.id + '_checkpoint', # type: ignore
|
||||
type='model',
|
||||
metadata={
|
||||
"save/epoch": epoch,
|
||||
"save/env_step": env_step,
|
||||
"save/gradient_step": gradient_step,
|
||||
"checkpoint_path": str(checkpoint_path)
|
||||
}
|
||||
)
|
||||
checkpoint_artifact.add_file(str(checkpoint_path))
|
||||
self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore
|
||||
|
||||
def restore_data(self) -> Tuple[int, int, int]:
|
||||
checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore
|
||||
'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore
|
||||
)
|
||||
assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
|
||||
|
||||
checkpoint_artifact.download(
|
||||
os.path.dirname(checkpoint_artifact.metadata['checkpoint_path'])
|
||||
)
|
||||
|
||||
try: # epoch / gradient_step
|
||||
epoch = checkpoint_artifact.metadata["save/epoch"]
|
||||
self.last_save_step = self.last_log_test_step = epoch
|
||||
gradient_step = checkpoint_artifact.metadata["save/gradient_step"]
|
||||
self.last_log_update_step = gradient_step
|
||||
except KeyError:
|
||||
epoch, gradient_step = 0, 0
|
||||
try: # offline trainer doesn't have env_step
|
||||
env_step = checkpoint_artifact.metadata["save/env_step"]
|
||||
self.last_log_train_step = env_step
|
||||
except KeyError:
|
||||
env_step = 0
|
||||
return epoch, env_step, gradient_step
|
||||
|
Loading…
x
Reference in New Issue
Block a user