finish concepts
This commit is contained in:
parent
0acd0d164c
commit
0e86d44860
@ -20,6 +20,8 @@ This command will run automatic tests in the main directory
|
||||
pytest test --cov tianshou -s
|
||||
```
|
||||
|
||||
To run on your own GitHub Repo, enable the [GitHub Action](/actions) and it will automatically run the test.
|
||||
|
||||
##### PEP8 Code Style Check
|
||||
|
||||
We follow PEP8 python code style. To check, in the main directory, run:
|
||||
|
@ -64,7 +64,7 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
|
||||
|
||||
Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It only uses 3 seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device)
|
||||
|
||||
```python
|
||||
```bash
|
||||
python3 test/discrete/test_pg.py --seed 0 --render 0.03
|
||||
```
|
||||
|
||||
|
@ -11,7 +11,7 @@ Tianshou splits a Reinforcement Learning agent training procedure into these par
|
||||
Data Batch
|
||||
----------
|
||||
|
||||
Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kinds of data to other method, for example, a collector gives a :class:`~tianshou.data.Batch` to policy for learning. Here is its usage:
|
||||
Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a collector gives a :class:`~tianshou.data.Batch` to policy for learning. Here is its usage:
|
||||
::
|
||||
|
||||
>>> import numpy as np
|
||||
@ -27,14 +27,14 @@ Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure t
|
||||
|
||||
In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair.
|
||||
|
||||
Current implementation of Tianshou typically use 6 keys in :class:`~tianshou.data.Batch`:
|
||||
The current implementation of Tianshou typically use 6 keys in :class:`~tianshou.data.Batch`:
|
||||
|
||||
* ``obs``: observation of step :math:`t` ;
|
||||
* ``act``: action of step :math:`t` ;
|
||||
* ``rew``: reward of step :math:`t` ;
|
||||
* ``obs``: the observation of step :math:`t` ;
|
||||
* ``act``: the action of step :math:`t` ;
|
||||
* ``rew``: the reward of step :math:`t` ;
|
||||
* ``done``: the done flag of step :math:`t` ;
|
||||
* ``obs_next``: observation of step :math:`t+1` ;
|
||||
* ``info``: info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function return 4 arguments, and the last one is ``info``);
|
||||
* ``obs_next``: the observation of step :math:`t+1` ;
|
||||
* ``info``: the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function return 4 arguments, and the last one is ``info``);
|
||||
|
||||
:class:`~tianshou.data.Batch` has other methods, including ``__getitem__``, ``append``, and ``split``:
|
||||
::
|
||||
@ -79,7 +79,7 @@ Data Buffer
|
||||
# since its size = 10, it only stores the last 10 steps' result.
|
||||
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
|
||||
|
||||
>>> # move buf2's result into buf (keep the order of time meanwhile)
|
||||
>>> # move buf2's result into buf (keep it chronologically meanwhile)
|
||||
>>> buf.update(buf2)
|
||||
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
|
||||
0., 0., 0., 0., 0., 0., 0.])
|
||||
@ -97,31 +97,102 @@ Policy
|
||||
|
||||
Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.
|
||||
|
||||
For demonstration, we use the source code of policy gradient :class:`~tianshou.policy.PGPolicy`. Policy gradient computes each frame's return as:
|
||||
A policy class typically has four parts:
|
||||
|
||||
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
|
||||
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
|
||||
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer);
|
||||
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data.
|
||||
|
||||
Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as:
|
||||
|
||||
.. math::
|
||||
|
||||
G_t = \sum_{i=t}^T \gamma^{i - t}r_i = r_t + \gamma r_{t + 1} + \cdots + \gamma^{T - t} r_T
|
||||
G_t = r_t + \gamma r_{t + 1} + \gamma^2 \max_a Q(s_{t + 2}, a)
|
||||
|
||||
, where :math:`T` is the terminal timestep, :math:`\gamma` is the discount factor, :math:`\gamma \in (0, 1]`.
|
||||
Here is the pseudocode showing the training process **without Tianshou framework**:
|
||||
::
|
||||
|
||||
This process is done in ``process_fn``
|
||||
# pseudocode, cannot work
|
||||
buffer = Buffer(size=10000)
|
||||
s = env.reset()
|
||||
agent = DQN()
|
||||
for i in range(int(1e6)):
|
||||
a = agent.compute_action(s)
|
||||
s_, r, d, _ = env.step(a)
|
||||
buffer.store(s, a, s_, r, d)
|
||||
s = s_
|
||||
if i % 1000 == 0:
|
||||
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
|
||||
# compute 2-step returns. How?
|
||||
b_ret = compute_2_step_return(buffer, b_r, b_d, ...)
|
||||
# update DQN
|
||||
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)
|
||||
|
||||
Thus, we need a time-dependent interface for calculating the 2-step return. :meth:`~tianshou.policy.BasePolicy.process_fn` provides this interface by giving the replay buffer, the sample index, and the sample batch data. Since we store all the data in the order of time, you can simply compute the 2-step return as:
|
||||
::
|
||||
|
||||
class DQN_2step(BasePolicy):
|
||||
"""some code"""
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
buffer_len = len(buffer)
|
||||
batch_2 = buffer[(indice + 2) % buffer_len]
|
||||
# this will return a batch data where batch_2.obs is s_t+2
|
||||
# we can also get s_t+2 through:
|
||||
# batch_2_obs = buffer.obs[(indice + 2) % buffer_len]
|
||||
Q = self(batch_2, eps=0) # shape: [batchsize, action_shape]
|
||||
maxQ = Q.max(dim=-1)
|
||||
batch.returns = batch.rew \
|
||||
+ self._gamma * buffer.rew[(indice + 1) % buffer_len] \
|
||||
+ self._gamma ** 2 * maxQ
|
||||
return batch
|
||||
|
||||
This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.policy.BasePolicy.process_fn`.
|
||||
|
||||
For other method, you can check out the API documentation for more detail. We give a high-level explanation through the same pseudocode:
|
||||
::
|
||||
|
||||
# pseudocode, cannot work
|
||||
buffer = Buffer(size=10000)
|
||||
s = env.reset()
|
||||
agent = DQN() # done in policy.__init__(...)
|
||||
for i in range(int(1e6)): # done in trainer
|
||||
a = agent.compute_action(s) # done in policy.__call__(batch, ...)
|
||||
s_, r, d, _ = env.step(a) # done in collector.collect(...)
|
||||
buffer.store(s, a, s_, r, d) # done in collector.collect(...)
|
||||
s = s_ # done in collector.collect(...)
|
||||
if i % 1000 == 0: # done in trainer
|
||||
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # done in collector.sample(batch_size)
|
||||
# compute 2-step returns. How?
|
||||
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # done in policy.process_fn(batch, buffer, indice)
|
||||
# update DQN
|
||||
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # done in policy.learn(batch, ...)
|
||||
|
||||
|
||||
Collector
|
||||
---------
|
||||
|
||||
TODO
|
||||
The collector enables the policy to interact with different types of environments conveniently.
|
||||
|
||||
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
|
||||
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
|
||||
|
||||
Why do we mention **at least** here? For a single environment, the collector will finish exactly ``n_step`` or ``n_episode``. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
|
||||
|
||||
The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
|
||||
|
||||
The general explanation is listed in the pseudocode above. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation.
|
||||
|
||||
|
||||
Trainer
|
||||
-------
|
||||
|
||||
Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop.
|
||||
Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.
|
||||
|
||||
Tianshou has two types of trainer: :meth:`~tianshou.trainer.onpolicy_trainer` and :meth:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out the API documentation for the usage.
|
||||
|
||||
There will be more types of trainer, for instance, multi-agent trainer.
|
||||
There will be more types of trainers, for instance, multi-agent trainer.
|
||||
|
||||
|
||||
Conclusion
|
||||
|
@ -178,6 +178,8 @@ Watch the Agent's Performance
|
||||
collector.close()
|
||||
|
||||
|
||||
.. _customized_trainer:
|
||||
|
||||
Train a Policy with Customized Codes
|
||||
------------------------------------
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user