update dqn tutorial
This commit is contained in:
parent
d9e4b9d16f
commit
4e7df7616a
@ -34,7 +34,7 @@ Documents are written under the `docs/` directory as RestructuredText (`.rst`) f
|
||||
API References are automatically generated by [Sphinx](http://www.sphinx-doc.org/en/stable/) according to the outlines under
|
||||
`doc/api/` and should be modified when any code changes.
|
||||
|
||||
To compile docs into webpages, Run
|
||||
To compile docs into webpages, run
|
||||
```
|
||||
make html
|
||||
```
|
||||
|
43
README.md
43
README.md
@ -6,12 +6,13 @@
|
||||
[](https://pypi.org/project/tianshou/)
|
||||
[](https://github.com/thu-ml/tianshou/actions)
|
||||
[](https://tianshou.readthedocs.io)
|
||||
[](https://github.com/thu-ml/tianshou/issues)
|
||||
[](https://github.com/thu-ml/tianshou/stargazers)
|
||||
[](https://github.com/thu-ml/tianshou/network)
|
||||
[](https://github.com/thu-ml/tianshou/issues)
|
||||
[](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
|
||||
[](https://gitter.im/thu-ml/tianshou?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
**Tianshou**(天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
**Tianshou** (天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
|
||||
|
||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
||||
@ -131,18 +132,14 @@ You can check out the [documentation](https://tianshou.readthedocs.io) for furth
|
||||
|
||||
## Quick Start
|
||||
|
||||
This is an example of Deep Q Network. You can also run the full script under [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
|
||||
This is an example of Deep Q Network. You can also run the full script at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
|
||||
|
||||
First, import some relevant packages:
|
||||
|
||||
```python
|
||||
import gym, torch, numpy as np, torch.nn as nn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
import tianshou as ts
|
||||
```
|
||||
|
||||
Define some hyper-parameters:
|
||||
@ -166,12 +163,9 @@ writer = SummaryWriter('log/dqn') # tensorboard is also supported!
|
||||
Make environments:
|
||||
|
||||
```python
|
||||
env = gym.make(task)
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
action_shape = env.action_space.shape or env.action_space.n
|
||||
# you can also try with SubprocVectorEnv
|
||||
train_envs = VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
|
||||
test_envs = VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
|
||||
train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
|
||||
test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
|
||||
```
|
||||
|
||||
Define the network:
|
||||
@ -193,6 +187,9 @@ class Net(nn.Module):
|
||||
logits = self.model(s.view(batch, -1))
|
||||
return logits, state
|
||||
|
||||
env = gym.make(task)
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
action_shape = env.action_space.shape or env.action_space.n
|
||||
net = Net(state_shape, action_shape)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=lr)
|
||||
```
|
||||
@ -200,16 +197,16 @@ optim = torch.optim.Adam(net.parameters(), lr=lr)
|
||||
Setup policy and collectors:
|
||||
|
||||
```python
|
||||
policy = DQNPolicy(net, optim, gamma, n_step,
|
||||
use_target_network=True, target_update_freq=target_freq)
|
||||
train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
|
||||
use_target_network=True, target_update_freq=target_freq)
|
||||
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
|
||||
test_collector = ts.data.Collector(policy, test_envs)
|
||||
```
|
||||
|
||||
Let's train it:
|
||||
|
||||
```python
|
||||
result = offpolicy_trainer(
|
||||
result = ts.trainer.offpolicy_trainer(
|
||||
policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
|
||||
test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
|
||||
test_fn=lambda e: policy.set_eps(eps_test),
|
||||
@ -217,7 +214,7 @@ result = offpolicy_trainer(
|
||||
print(f'Finished training! Use {result["duration"]}')
|
||||
```
|
||||
|
||||
Saving / loading trained policy (it's exactly the same as PyTorch nn.module):
|
||||
Save / load the trained policy (it's exactly the same as PyTorch nn.module):
|
||||
|
||||
```python
|
||||
torch.save(policy.state_dict(), 'dqn.pth')
|
||||
@ -226,13 +223,13 @@ policy.load_state_dict(torch.load('dqn.pth'))
|
||||
|
||||
Watch the performance with 35 FPS:
|
||||
|
||||
```python3
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1/35)
|
||||
```python
|
||||
collector = ts.data.Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.close()
|
||||
```
|
||||
|
||||
Looking at the result saved in tensorboard: (on bash script)
|
||||
Look at the result saved in tensorboard: (on bash script)
|
||||
|
||||
```bash
|
||||
tensorboard --logdir log/dqn
|
||||
|
@ -50,6 +50,7 @@ extensions = [
|
||||
'sphinx.ext.ifconfig',
|
||||
'sphinx.ext.viewcode',
|
||||
'sphinx.ext.githubpages',
|
||||
'sphinxcontrib.bibtex',
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
|
@ -49,6 +49,8 @@ If no error occurs, you have successfully installed Tianshou.
|
||||
:maxdepth: 1
|
||||
:caption: Tutorials
|
||||
|
||||
tutorials/dqn
|
||||
tutorials/concepts
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
69
docs/refs.bib
Normal file
69
docs/refs.bib
Normal file
@ -0,0 +1,69 @@
|
||||
@article{DQN,
|
||||
author = {Volodymyr Mnih and
|
||||
Koray Kavukcuoglu and
|
||||
David Silver and
|
||||
Andrei A. Rusu and
|
||||
Joel Veness and
|
||||
Marc G. Bellemare and
|
||||
Alex Graves and
|
||||
Martin A. Riedmiller and
|
||||
Andreas Fidjeland and
|
||||
Georg Ostrovski and
|
||||
Stig Petersen and
|
||||
Charles Beattie and
|
||||
Amir Sadik and
|
||||
Ioannis Antonoglou and
|
||||
Helen King and
|
||||
Dharshan Kumaran and
|
||||
Daan Wierstra and
|
||||
Shane Legg and
|
||||
Demis Hassabis},
|
||||
title = {Human-level control through deep reinforcement learning},
|
||||
journal = {Nature},
|
||||
volume = {518},
|
||||
number = {7540},
|
||||
pages = {529--533},
|
||||
year = {2015},
|
||||
url = {https://doi.org/10.1038/nature14236},
|
||||
doi = {10.1038/nature14236},
|
||||
timestamp = {Wed, 14 Nov 2018 10:30:43 +0100},
|
||||
biburl = {https://dblp.org/rec/journals/nature/MnihKSRVBGRFOPB15.bib},
|
||||
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||
}
|
||||
|
||||
@inproceedings{DDPG,
|
||||
author = {Timothy P. Lillicrap and
|
||||
Jonathan J. Hunt and
|
||||
Alexander Pritzel and
|
||||
Nicolas Heess and
|
||||
Tom Erez and
|
||||
Yuval Tassa and
|
||||
David Silver and
|
||||
Daan Wierstra},
|
||||
title = {Continuous control with deep reinforcement learning},
|
||||
booktitle = {4th International Conference on Learning Representations, {ICLR} 2016,
|
||||
San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings},
|
||||
year = {2016},
|
||||
url = {http://arxiv.org/abs/1509.02971},
|
||||
timestamp = {Thu, 25 Jul 2019 14:25:37 +0200},
|
||||
biburl = {https://dblp.org/rec/journals/corr/LillicrapHPHETS15.bib},
|
||||
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||
}
|
||||
|
||||
@article{PPO,
|
||||
author = {John Schulman and
|
||||
Filip Wolski and
|
||||
Prafulla Dhariwal and
|
||||
Alec Radford and
|
||||
Oleg Klimov},
|
||||
title = {Proximal Policy Optimization Algorithms},
|
||||
journal = {CoRR},
|
||||
volume = {abs/1707.06347},
|
||||
year = {2017},
|
||||
url = {http://arxiv.org/abs/1707.06347},
|
||||
archivePrefix = {arXiv},
|
||||
eprint = {1707.06347},
|
||||
timestamp = {Mon, 13 Aug 2018 16:47:34 +0200},
|
||||
biburl = {https://dblp.org/rec/journals/corr/SchulmanWDRK17.bib},
|
||||
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||
}
|
1
docs/requirements.txt
Normal file
1
docs/requirements.txt
Normal file
@ -0,0 +1 @@
|
||||
sphinxcontrib-bibtex
|
4
docs/tutorials/concepts.rst
Normal file
4
docs/tutorials/concepts.rst
Normal file
@ -0,0 +1,4 @@
|
||||
Basic Concepts in Tianshou
|
||||
==========================
|
||||
|
||||
Under construction...
|
216
docs/tutorials/dqn.rst
Normal file
216
docs/tutorials/dqn.rst
Normal file
@ -0,0 +1,216 @@
|
||||
Deep Q Network
|
||||
==============
|
||||
|
||||
Deep reinforcement learning has achieved significant successes in various applications.
|
||||
**Deep Q Network** (DQN) :cite:`DQN` is the pioneer one.
|
||||
In this tutorial, we will show how to train a DQN agent on CartPole with Tianshou step by step.
|
||||
The full script is at `test/discrete/test_dqn.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py>`_.
|
||||
|
||||
Contrary to existing Deep RL libraries such as `RLlib <https://github.com/ray-project/ray/tree/master/rllib/>`_, which could only accept a config specification of hyperparameters, network, and others, Tianshou provides an easy way of construction through the code-level.
|
||||
|
||||
|
||||
Make an Environment
|
||||
-------------------
|
||||
|
||||
First of all, you have to make an environment for your agent to act in. For the environment interfaces, we follow the convention of `OpenAI Gym <https://github.com/openai/gym>`_. In your Python code, simply import Tianshou and make the environment
|
||||
::
|
||||
|
||||
import gym
|
||||
import tianshou as ts
|
||||
|
||||
env = gym.make('CartPole-v0')
|
||||
|
||||
CartPole-v0 is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both, depending on the probability distribution on the action.
|
||||
|
||||
|
||||
Setup Multi-environment Wrapper
|
||||
-------------------------------
|
||||
|
||||
It is available if you want the original ``gym.Env``:
|
||||
::
|
||||
|
||||
train_envs = gym.make('CartPole-v0')
|
||||
test_envs = gym.make('CartPole-v0')
|
||||
|
||||
Tianshou supports parallel sampling for all algorithms. It provides three types of vectorized environment wrapper: :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows:
|
||||
::
|
||||
|
||||
train_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
|
||||
test_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
|
||||
|
||||
Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.
|
||||
|
||||
For demonstrating, here we use the second block of codes.
|
||||
|
||||
|
||||
Build the Network
|
||||
-----------------
|
||||
|
||||
Tianshou supports any user-defined PyTorch networks and optimizers but with the limitation of input and output API. Here is an example code:
|
||||
::
|
||||
|
||||
import torch, numpy as np
|
||||
from torch import nn
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, state_shape, action_shape):
|
||||
super().__init__()
|
||||
self.model = nn.Sequential(*[
|
||||
nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
|
||||
nn.Linear(128, 128), nn.ReLU(inplace=True),
|
||||
nn.Linear(128, 128), nn.ReLU(inplace=True),
|
||||
nn.Linear(128, np.prod(action_shape))
|
||||
])
|
||||
def forward(self, obs, state=None, info={}):
|
||||
if not isinstance(obs, torch.Tensor):
|
||||
obs = torch.tensor(obs, dtype=torch.float)
|
||||
batch = obs.shape[0]
|
||||
logits = self.model(obs.view(batch, -1))
|
||||
return logits, state
|
||||
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
action_shape = env.action_space.shape or env.action_space.n
|
||||
net = Net(state_shape, action_shape)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
|
||||
|
||||
The rules of self-defined networks are:
|
||||
|
||||
1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or ``torch.Tensor``), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
|
||||
2. Output: some ``logits`` and the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy.
|
||||
|
||||
|
||||
Setup Policy
|
||||
------------
|
||||
|
||||
We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with using a target network:
|
||||
::
|
||||
|
||||
policy = ts.policy.DQNPolicy(net, optim,
|
||||
discount_factor=0.9, estimation_step=3,
|
||||
use_target_network=True, target_update_freq=320)
|
||||
|
||||
|
||||
Setup Collector
|
||||
---------------
|
||||
|
||||
The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently.
|
||||
In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.
|
||||
::
|
||||
|
||||
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
|
||||
test_collector = ts.data.Collector(policy, test_envs)
|
||||
|
||||
|
||||
Train Policy with a Trainer
|
||||
---------------------------
|
||||
|
||||
Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows:
|
||||
::
|
||||
|
||||
result = ts.trainer.offpolicy_trainer(
|
||||
policy, train_collector, test_collector,
|
||||
max_epoch=10, step_per_epoch=1000, collect_per_step=10,
|
||||
episode_per_test=100, batch_size=64,
|
||||
train_fn=lambda e: policy.set_eps(0.1),
|
||||
test_fn=lambda e: policy.set_eps(0.05),
|
||||
stop_fn=lambda x: x >= env.spec.reward_threshold,
|
||||
writer=None)
|
||||
print(f'Finished training! Use {result["duration"]}')
|
||||
|
||||
The meaning of each parameter is as follows:
|
||||
|
||||
* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
|
||||
* ``step_per_epoch``: The number of step for updating policy network in one epoch;
|
||||
* ``collect_per_step``: The number of frame the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
|
||||
* ``episode_per_test``: The number of episode for one policy evaluation.
|
||||
* ``batch_size``: The batch size of sampled data, which is going to feed in the policy network.
|
||||
* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
|
||||
* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
|
||||
* ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal.
|
||||
* ``writer``: See below.
|
||||
|
||||
The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for logging. It can be used as:
|
||||
::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
writer = SummaryWriter('log/dqn')
|
||||
|
||||
Pass the writer into the trainer, and the training result will be recorded into the TensorBoard.
|
||||
|
||||
The returned result is a dictionary as follows:
|
||||
::
|
||||
|
||||
{
|
||||
'train_step': 9246,
|
||||
'train_episode': 504.0,
|
||||
'train_time/collector': '0.65s',
|
||||
'train_time/model': '1.97s',
|
||||
'train_speed': '3518.79 step/s',
|
||||
'test_step': 49112,
|
||||
'test_episode': 400.0,
|
||||
'test_time': '1.38s',
|
||||
'test_speed': '35600.52 step/s',
|
||||
'best_reward': 199.03,
|
||||
'duration': '4.01s'
|
||||
}
|
||||
|
||||
It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03.
|
||||
|
||||
|
||||
Save/Load Policy
|
||||
----------------
|
||||
|
||||
Since the policy inherits the ``torch.nn.Module`` class, saving and loading the policy are exactly the same as a torch module:
|
||||
::
|
||||
|
||||
torch.save(policy.state_dict(), 'dqn.pth')
|
||||
policy.load_state_dict(torch.load('dqn.pth'))
|
||||
|
||||
|
||||
Watch the Performance
|
||||
---------------------
|
||||
|
||||
:class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent in 35 FPS:
|
||||
::
|
||||
|
||||
collector = ts.data.Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.close()
|
||||
|
||||
|
||||
Train a Policy with Customized Codes
|
||||
------------------------------------
|
||||
|
||||
"I don't want to use your provided trainer. I want to customize it!"
|
||||
|
||||
No problem! Here are the usages:
|
||||
::
|
||||
|
||||
# pre-collect 5000 frames with random action before training
|
||||
policy.set_eps(1)
|
||||
train_collector.collect(n_step=5000)
|
||||
|
||||
policy.set_eps(0.1)
|
||||
for i in range(int(1e6)): # total step
|
||||
collect_result = train_collector.collect(n_step=10)
|
||||
|
||||
# once if the collected episodes' mean returns reach the threshold,
|
||||
# or every 1000 steps, we test it on test_collector
|
||||
if collect_result['rew'] >= env.spec.reward_threshold or i % 1000 == 0:
|
||||
policy.set_eps(0.05)
|
||||
result = test_collector.collect(n_episode=100)
|
||||
if result['rew'] >= env.spec.reward_threshold:
|
||||
print(f'Finished training! Test mean returns: {result["rew"]}')
|
||||
break
|
||||
else:
|
||||
# back to training eps
|
||||
policy.set_eps(0.1)
|
||||
|
||||
# train policy with a sampled batch data
|
||||
losses = policy.learn(train_collector.sample(batch_size=64))
|
||||
|
||||
|
||||
.. rubric:: References
|
||||
|
||||
.. bibliography:: ../refs.bib
|
||||
:style: unsrtalpha
|
Loading…
x
Reference in New Issue
Block a user