Deep Q Network
==============
Deep reinforcement learning has achieved significant successes in various applications.
**Deep Q Network** (DQN) :cite:`DQN` is one of its pioneering algorithms.
In this tutorial, we will show how to train a DQN agent on CartPole with Tianshou step by step.
The full script is at `test/discrete/test_dqn.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py>`_.
Unlike existing deep RL libraries such as `RLlib <https://github.com/ray-project/ray/tree/master/rllib/>`_, which typically accept only a configuration specification of hyperparameters, networks, and so on, Tianshou provides an easy way to construct everything at the code level.

Make an Environment
-------------------

First of all, you have to make an environment for your agent to interact with. For environment interfaces, we follow the convention of `OpenAI Gym <https://github.com/openai/gym>`_. In your Python code, simply import Tianshou and make the environment:

::

    import gym
    import tianshou as ts

    env = gym.make('CartPole-v0')

CartPole-v0 is a simple environment with a discrete action space, to which DQN applies. You have to identify whether the action space is continuous or discrete and apply an eligible algorithm. DDPG :cite:`DDPG`, for example, can only be applied to continuous action spaces, while almost all other policy gradient methods can be applied to both, depending on the probability distribution used for the action.
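
A quick way to check this (not part of the original script) is to inspect ``env.action_space`` directly. CartPole-v0 exposes a discrete space, so a discrete-action algorithm such as DQN is applicable:

::

    from gym import spaces

    # CartPole-v0 has two discrete actions (push the cart left or right).
    print(env.action_space)                               # Discrete(2)
    print(isinstance(env.action_space, spaces.Discrete))  # True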

Setup Multi-environment Wrapper
-------------------------------

If you want to use the original ``gym.Env`` directly, that is also fine:

::

    train_envs = gym.make('CartPole-v0')
    test_envs = gym.make('CartPole-v0')

Tianshou supports parallel sampling for all algorithms. It provides three types of vectorized environment wrapper: :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. They can be used as follows:

::

    train_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
    test_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])

Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.

For this demonstration, we use the second code block.
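
If stepping the environment is expensive, you could swap in the subprocess-based wrapper instead; a minimal sketch, assuming it accepts the same list of environment factories as :class:`~tianshou.env.VectorEnv`:

::

    # Each environment runs in its own worker process, which helps when
    # env.step() is slow; the constructor mirrors the VectorEnv call above.
    train_envs = ts.env.SubprocVectorEnv(
        [lambda: gym.make('CartPole-v0') for _ in range(8)])
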
.. _build_the_network:

Build the Network
-----------------

Tianshou supports any user-defined PyTorch networks and optimizers, as long as they follow the input and output conventions described below. Here is an example:

::

    import torch, numpy as np
    from torch import nn

    class Net(nn.Module):
        def __init__(self, state_shape, action_shape):
            super().__init__()
            self.model = nn.Sequential(*[
                nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
                nn.Linear(128, 128), nn.ReLU(inplace=True),
                nn.Linear(128, 128), nn.ReLU(inplace=True),
                nn.Linear(128, np.prod(action_shape))
            ])

        def forward(self, obs, state=None, info={}):
            if not isinstance(obs, torch.Tensor):
                obs = torch.tensor(obs, dtype=torch.float)
            batch = obs.shape[0]
            logits = self.model(obs.view(batch, -1))
            return logits, state

    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    net = Net(state_shape, action_shape)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)

You can also try the pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules for self-defined networks are:

1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
2. Output: some ``logits``, the next hidden state ``state``, and the intermediate result ``policy`` produced during the policy forwarding procedure. The logits can be a tuple instead of a ``torch.Tensor``, depending on how the policy processes the network output. For example, in PPO :cite:`PPO`, the network may return ``(mu, sigma), state`` for a Gaussian policy. The ``policy`` can be a Batch of ``torch.Tensor`` or other data; it will be stored in the replay buffer and can be accessed during the policy update (e.g. in ``policy.learn()``, ``batch.policy`` is what you need).
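
As a quick sanity check (purely illustrative, and assuming the classic Gym ``reset()`` API that returns only the observation), you can pass a small batch of observations through the network and verify that it follows this convention:

::

    # Run a forward pass on a batch of 4 observations.
    obs = np.stack([env.reset() for _ in range(4)])
    logits, state = net(obs)
    print(logits.shape)  # torch.Size([4, 2]) for CartPole-v0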

Setup Policy
------------

We use the ``net`` and ``optim`` defined above, together with some policy hyper-parameters, to define a DQN policy with a target network:

::

    policy = ts.policy.DQNPolicy(net, optim,
        discount_factor=0.9, estimation_step=3,
        target_update_freq=320)
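
Roughly speaking (our reading of these hyper-parameters, not a formal specification): ``discount_factor`` is the discount :math:`\gamma`, ``estimation_step=3`` regresses the Q-network towards a 3-step return bootstrapped from the target network, and ``target_update_freq=320`` synchronizes the target network with the online network every 320 updates. The 3-step target then has the form

.. math::

    G_t = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \gamma^3 \max_a Q_{\text{target}}(s_{t+3}, a).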

Setup Collector
---------------

The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently.
In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.

::

    train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
    test_collector = ts.data.Collector(policy, test_envs)
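
As a small illustration of what the collector does (the same calls appear in the customized training loop at the end of this tutorial):

::

    # Let the policy interact with the training environments for 10 steps
    # and store the resulting transitions in the replay buffer.
    collect_result = train_collector.collect(n_step=10)

    # Draw a mini-batch from the replay buffer, e.g. for a manual update.
    batch_data = train_collector.sample(batch_size=64)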

Train Policy with a Trainer
---------------------------

Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reaches the stop condition ``stop_fn`` on the test collector. Since DQN is an off-policy algorithm, we use :class:`~tianshou.trainer.offpolicy_trainer` as follows:

::

    result = ts.trainer.offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=1000, collect_per_step=10,
        episode_per_test=100, batch_size=64,
        train_fn=lambda e: policy.set_eps(0.1),
        test_fn=lambda e: policy.set_eps(0.05),
        stop_fn=lambda x: x >= env.spec.reward_threshold,
        writer=None)
    print(f'Finished training! Use {result["duration"]}')

The meaning of each parameter is as follows:

* ``max_epoch``: The maximum number of epochs for training. The training process might finish before reaching ``max_epoch``;
* ``step_per_epoch``: The number of policy network updates performed in one epoch;
* ``collect_per_step``: The number of frames the collector collects before each network update. For example, the code above means "collect 10 frames, then do one policy network update";
* ``episode_per_test``: The number of episodes used for one policy evaluation;
* ``batch_size``: The batch size of the sample data that is fed into the policy network;
* ``train_fn``: A function that receives the current epoch index and performs some operations at the beginning of training in this epoch. For example, the code above resets epsilon to 0.1 in DQN before training (see the sketch after this list for a decaying schedule);
* ``test_fn``: A function that receives the current epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above resets epsilon to 0.05 in DQN before testing;
* ``stop_fn``: A function that receives the average undiscounted return of the test results and returns a boolean indicating whether the goal has been reached;
* ``writer``: See below.
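
Since ``train_fn`` only receives the epoch index, any exploration schedule can be implemented inside it. A hypothetical linearly-decaying epsilon (not used in the original script) might look like:

::

    # Decay epsilon from 0.5 to 0.05 over the first five epochs,
    # then keep it fixed at 0.05.
    def train_fn(epoch):
        policy.set_eps(max(0.5 - 0.09 * epoch, 0.05))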
The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for logging. It can be used as:

::

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter('log/dqn')

Pass the writer into the trainer, and the training result will be recorded in TensorBoard, as shown below.
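
For example, repeating the trainer call from above with ``writer=writer`` instead of ``writer=None``:

::

    result = ts.trainer.offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=1000, collect_per_step=10,
        episode_per_test=100, batch_size=64,
        train_fn=lambda e: policy.set_eps(0.1),
        test_fn=lambda e: policy.set_eps(0.05),
        stop_fn=lambda x: x >= env.spec.reward_threshold,
        writer=writer)  # training statistics are now logged to 'log/dqn'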
The returned result is a dictionary as follows:

::

    {
        'train_step': 9246,
        'train_episode': 504.0,
        'train_time/collector': '0.65s',
        'train_time/model': '1.97s',
        'train_speed': '3518.79 step/s',
        'test_step': 49112,
        'test_episode': 400.0,
        'test_time': '1.38s',
        'test_speed': '35600.52 step/s',
        'best_reward': 199.03,
        'duration': '4.01s'
    }

It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean return over 100 consecutive test episodes is 199.03.

Save/Load Policy
----------------

Since the policy inherits from ``torch.nn.Module``, saving and loading it work exactly the same as for any torch module:

::

    torch.save(policy.state_dict(), 'dqn.pth')
    policy.load_state_dict(torch.load('dqn.pth'))

Watch the Agent's Performance
-----------------------------

:class:`~tianshou.data.Collector` supports rendering. Here is an example of watching the agent's performance at 35 FPS:

::

    collector = ts.data.Collector(policy, env)
    collector.collect(n_episode=1, render=1 / 35)
    collector.close()

.. _customized_trainer:

Train a Policy with Customized Codes
------------------------------------

"I don't want to use your provided trainer. I want to customize it!"

No problem! Tianshou supports user-defined training code. Here is the usage:

::

    # pre-collect 5000 frames with random actions before training
    policy.set_eps(1)
    train_collector.collect(n_step=5000)

    policy.set_eps(0.1)
    for i in range(int(1e6)):  # total step
        collect_result = train_collector.collect(n_step=10)

        # if the mean return of the collected episodes reaches the threshold,
        # or every 1000 steps, evaluate the policy on test_collector
        if collect_result['rew'] >= env.spec.reward_threshold or i % 1000 == 0:
            policy.set_eps(0.05)
            result = test_collector.collect(n_episode=100)
            if result['rew'] >= env.spec.reward_threshold:
                print(f'Finished training! Test mean returns: {result["rew"]}')
                break
            else:
                # switch back to the training epsilon
                policy.set_eps(0.1)

        # train the policy with a batch sampled from the replay buffer
        losses = policy.learn(train_collector.sample(batch_size=64))

For further usage, you can refer to :doc:`/tutorials/cheatsheet`.

.. rubric:: References

.. bibliography:: /refs.bib
    :style: unsrtalpha