test api doc
This commit is contained in:
parent
0b08a41610
commit
0acd0d164c
27
README.md
27
README.md
@ -13,7 +13,7 @@
|
||||
[](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
|
||||
[](https://gitter.im/thu-ml/tianshou?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
**Tianshou** ([天授]([https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88/9342](https://baike.baidu.com/item/天授/9342))) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
|
||||
|
||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
||||
@ -25,7 +25,7 @@
|
||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
||||
|
||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
|
||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
|
||||
|
||||
## Installation
|
||||
|
||||
@ -62,13 +62,17 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
|
||||
|
||||
### Fast-speed
|
||||
|
||||
Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It only uses 3 seconds for training a agent based on vanilla policy gradient on the CartPole-v0 task: `python3 test/discrete/test_pg.py --seed 0 --render 0.03` (seed may be different across different platform and device)
|
||||
Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It only uses 3 seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device)
|
||||
|
||||
```python
|
||||
python3 test/discrete/test_pg.py --seed 0 --render 0.03
|
||||
```
|
||||
|
||||
<div align="center">
|
||||
<img src="docs/_static/images/testpg.gif"></a>
|
||||
</div>
|
||||
|
||||
We select some of famous reinforcement learning platforms: 2 GitHub repo with most stars in all RL platforms (Baselines, RLlib) and 2 GitHub repo with most stars in PyTorch RL platforms (PyTorch DRL and rlpyt). Here is the benchmark result for other algorithms and platforms on toy scenarios: (tested on the same laptop as mentioned above)
|
||||
We select some of famous reinforcement learning platforms: 2 GitHub repos with most stars in all RL platforms (OpenAI Baseline and RLlib) and 2 GitHub repos with most stars in PyTorch RL platforms (PyTorch DRL and rlpyt). Here is the benchmark result for other algorithms and platforms on toy scenarios: (tested on the same laptop as mentioned above)
|
||||
|
||||
| RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
|
||||
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||
@ -111,13 +115,13 @@ Within these API, we can interact with different policies conveniently.
|
||||
|
||||
### Elegant and Flexible
|
||||
|
||||
Currently, the overall code of Tianshou platform is less than 1500 lines. Most of the implemented algorithms are less than 100 lines of python code. It is quite easy to go through the framework and understand how it works. We provide many flexible API as you wish, for instance, if you want to use your policy to interact with environment with `n` steps:
|
||||
Currently, the overall code of Tianshou platform is less than 1500 lines without environment wrappers for Atari and Mujoco. Most of the implemented algorithms are less than 100 lines of python code. It is quite easy to go through the framework and understand how it works. We provide many flexible API as you wish, for instance, if you want to use your policy to interact with the environment with (at least) `n` steps:
|
||||
|
||||
```python
|
||||
result = collector.collect(n_step=n)
|
||||
```
|
||||
|
||||
If you have 3 environment in total and want to collect 1 episode in the first environment, 3 for third environment:
|
||||
If you have 3 environments in total and want to collect 1 episode in the first environment, 3 for the third environment:
|
||||
|
||||
```python
|
||||
result = collector.collect(n_episode=[1, 0, 3])
|
||||
@ -244,7 +248,7 @@ Tianshou is still under development. More algorithms and features are going to b
|
||||
|
||||
## TODO
|
||||
|
||||
- [x] More examples on [mujoco, atari] benchmark
|
||||
- [ ] More examples on [mujoco, atari] benchmark
|
||||
- [ ] More algorithms
|
||||
- [ ] Prioritized replay buffer
|
||||
- [ ] RNN support
|
||||
@ -267,9 +271,8 @@ If you find Tianshou useful, please cite it in your publications.
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgment
|
||||
|
||||
Tianshou was previously a reinforcement learning platform based on TensorFlow. You can check out the branch [`priv`](https://github.com/thu-ml/tianshou/tree/priv) for more detail. Many thanks to [Haosheng Zou](https://github.com/HaoshengZou)'s pioneering work for `Tianshou<=0.1.1`.
|
||||
|
||||
We would like to thank [TSAIL](http://ml.cs.tsinghua.edu.cn/) and [Institute for Artificial Intelligence, Tsinghua University](http://ai.tsinghua.edu.cn/) for providing such an excellent AI research platform.
|
||||
|
||||
## Miscellaneous
|
||||
|
||||
Tianshou was previously a reinforcement learning platform based on TensorFlow. You can checkout the branch [`priv`](https://github.com/thu-ml/tianshou/tree/priv) for more detail.
|
||||
|
||||
|
7
docs/api/tianshou.data.rst
Normal file
7
docs/api/tianshou.data.rst
Normal file
@ -0,0 +1,7 @@
|
||||
tianshou.data
|
||||
=============
|
||||
|
||||
.. automodule:: tianshou.data
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/api/tianshou.env.rst
Normal file
7
docs/api/tianshou.env.rst
Normal file
@ -0,0 +1,7 @@
|
||||
tianshou.env
|
||||
============
|
||||
|
||||
.. automodule:: tianshou.env
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/api/tianshou.exploration.rst
Normal file
7
docs/api/tianshou.exploration.rst
Normal file
@ -0,0 +1,7 @@
|
||||
tianshou.exploration
|
||||
====================
|
||||
|
||||
.. automodule:: tianshou.exploration
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/api/tianshou.policy.rst
Normal file
7
docs/api/tianshou.policy.rst
Normal file
@ -0,0 +1,7 @@
|
||||
tianshou.policy
|
||||
===============
|
||||
|
||||
.. automodule:: tianshou.policy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/api/tianshou.trainer.rst
Normal file
7
docs/api/tianshou.trainer.rst
Normal file
@ -0,0 +1,7 @@
|
||||
tianshou.trainer
|
||||
================
|
||||
|
||||
.. automodule:: tianshou.trainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/api/tianshou.utils.rst
Normal file
7
docs/api/tianshou.utils.rst
Normal file
@ -0,0 +1,7 @@
|
||||
tianshou.utils
|
||||
==============
|
||||
|
||||
.. automodule:: tianshou.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
10
docs/conf.py
10
docs/conf.py
@ -14,16 +14,12 @@
|
||||
# import sys
|
||||
# sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
|
||||
import tianshou
|
||||
import sphinx_rtd_theme
|
||||
|
||||
import re
|
||||
from os import path
|
||||
|
||||
here = path.abspath(path.dirname(__file__))
|
||||
|
||||
# Get the version string
|
||||
with open(path.join(here, '..', 'tianshou', '__init__.py')) as f:
|
||||
version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1)
|
||||
version = tianshou.__version__
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
|
@ -6,7 +6,7 @@
|
||||
Welcome to Tianshou!
|
||||
====================
|
||||
|
||||
**Tianshou** (天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
**Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
|
||||
* `Policy Gradient (PG) <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||
* `Deep Q-Network (DQN) <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
|
||||
@ -20,6 +20,7 @@ Welcome to Tianshou!
|
||||
|
||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
@ -28,14 +29,11 @@ Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. Yo
|
||||
|
||||
pip3 install tianshou
|
||||
|
||||
|
||||
You can also install with the newest version through GitHub:
|
||||
|
||||
::
|
||||
|
||||
pip3 install git+https://github.com/thu-ml/tianshou.git@master
|
||||
|
||||
|
||||
After installation, open your python console and type
|
||||
::
|
||||
|
||||
@ -56,6 +54,12 @@ If no error occurs, you have successfully installed Tianshou.
|
||||
:maxdepth: 1
|
||||
:caption: API Docs
|
||||
|
||||
api/tianshou.data
|
||||
api/tianshou.env
|
||||
api/tianshou.policy
|
||||
api/tianshou.trainer
|
||||
api/tianshou.exploration
|
||||
api/tianshou.utils
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
@ -1,11 +1,11 @@
|
||||
Basic Concepts in Tianshou
|
||||
==========================
|
||||
|
||||
Tianshou has split a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be discribed as:
|
||||
Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as:
|
||||
|
||||
.. image:: ../_static/images/concepts_arch.png
|
||||
:align: center
|
||||
:height: 250
|
||||
:height: 300
|
||||
|
||||
|
||||
Data Batch
|
||||
@ -27,16 +27,16 @@ Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure t
|
||||
|
||||
In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair.
|
||||
|
||||
Current implementation of Tianshou typically use 6 keys:
|
||||
Current implementation of Tianshou typically use 6 keys in :class:`~tianshou.data.Batch`:
|
||||
|
||||
* ``obs``: observation of step t;
|
||||
* ``act``: action of step t;
|
||||
* ``rew``: reward of step t;
|
||||
* ``done``: the done flag of step t;
|
||||
* ``obs_next``: observation of step t+1;
|
||||
* ``info``: info of step t (in ``gym.Env``, the ``env.step()`` function return 4 arguments, and the last one is ``info``);
|
||||
* ``obs``: observation of step :math:`t` ;
|
||||
* ``act``: action of step :math:`t` ;
|
||||
* ``rew``: reward of step :math:`t` ;
|
||||
* ``done``: the done flag of step :math:`t` ;
|
||||
* ``obs_next``: observation of step :math:`t+1` ;
|
||||
* ``info``: info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function return 4 arguments, and the last one is ``info``);
|
||||
|
||||
:class:`~tianshou.data.Batch` has other methods:
|
||||
:class:`~tianshou.data.Batch` has other methods, including ``__getitem__``, ``append``, and ``split``:
|
||||
::
|
||||
|
||||
>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
|
||||
@ -51,7 +51,7 @@ Current implementation of Tianshou typically use 6 keys:
|
||||
|
||||
>>> # split whole data into multiple small batch
|
||||
>>> for d in data.split(size=2, permute=False):
|
||||
>>> print(d.obs, d.rew)
|
||||
... print(d.obs, d.rew)
|
||||
[ 0 11] [6 6]
|
||||
[22 0] [6 6]
|
||||
[11 22] [6 6]
|
||||
@ -66,7 +66,7 @@ Data Buffer
|
||||
>>> from tianshou.data import ReplayBuffer
|
||||
>>> buf = ReplayBuffer(size=20)
|
||||
>>> for i in range(3):
|
||||
>>> buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> buf.obs
|
||||
# since we set size = 20, len(buf.obs) == 20.
|
||||
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
@ -74,7 +74,7 @@ Data Buffer
|
||||
|
||||
>>> buf2 = ReplayBuffer(size=10)
|
||||
>>> for i in range(15):
|
||||
>>> buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> buf2.obs
|
||||
# since its size = 10, it only stores the last 10 steps' result.
|
||||
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
|
||||
@ -103,9 +103,9 @@ For demonstration, we use the source code of policy gradient :class:`~tianshou.p
|
||||
|
||||
G_t = \sum_{i=t}^T \gamma^{i - t}r_i = r_t + \gamma r_{t + 1} + \cdots + \gamma^{T - t} r_T
|
||||
|
||||
, where T is the terminal timestep, :math:`\gamma` is the discount factor :math:`\in [0, 1]`.
|
||||
, where :math:`T` is the terminal timestep, :math:`\gamma` is the discount factor, :math:`\gamma \in (0, 1]`.
|
||||
|
||||
TODO
|
||||
This process is done in ``process_fn``
|
||||
|
||||
|
||||
Collector
|
||||
|
@ -12,7 +12,7 @@ Contrary to existing Deep RL libraries such as `RLlib <https://github.com/ray-pr
|
||||
Make an Environment
|
||||
-------------------
|
||||
|
||||
First of all, you have to make an environment for your agent to act in. For the environment interfaces, we follow the convention of `OpenAI Gym <https://github.com/openai/gym>`_. In your Python code, simply import Tianshou and make the environment
|
||||
First of all, you have to make an environment for your agent to interact with. For environment interfaces, we follow the convention of `OpenAI Gym <https://github.com/openai/gym>`_. In your Python code, simply import Tianshou and make the environment:
|
||||
::
|
||||
|
||||
import gym
|
||||
@ -40,7 +40,7 @@ Tianshou supports parallel sampling for all algorithms. It provides three types
|
||||
|
||||
Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.
|
||||
|
||||
For demonstrating, here we use the second block of codes.
|
||||
For the demonstration, here we use the second block of codes.
|
||||
|
||||
|
||||
Build the Network
|
||||
@ -121,9 +121,9 @@ The meaning of each parameter is as follows:
|
||||
|
||||
* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
|
||||
* ``step_per_epoch``: The number of step for updating policy network in one epoch;
|
||||
* ``collect_per_step``: The number of frame the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
|
||||
* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
|
||||
* ``episode_per_test``: The number of episode for one policy evaluation.
|
||||
* ``batch_size``: The batch size of sampled data, which is going to feed in the policy network.
|
||||
* ``batch_size``: The batch size of sample data, which is going to feed in the policy network.
|
||||
* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
|
||||
* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
|
||||
* ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal.
|
||||
@ -167,10 +167,10 @@ Since the policy inherits the ``torch.nn.Module`` class, saving and loading the
|
||||
policy.load_state_dict(torch.load('dqn.pth'))
|
||||
|
||||
|
||||
Watch the Performance
|
||||
---------------------
|
||||
Watch the Agent's Performance
|
||||
-----------------------------
|
||||
|
||||
:class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent in 35 FPS:
|
||||
:class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent's performance in 35 FPS:
|
||||
::
|
||||
|
||||
collector = ts.data.Collector(policy, env)
|
||||
@ -183,7 +183,7 @@ Train a Policy with Customized Codes
|
||||
|
||||
"I don't want to use your provided trainer. I want to customize it!"
|
||||
|
||||
No problem! Here is the usage:
|
||||
No problem! Tianshou supports user-defined training code. Here is the usage:
|
||||
::
|
||||
|
||||
# pre-collect 5000 frames with random action before training
|
||||
|
15
tianshou/env/vecenv.py
vendored
15
tianshou/env/vecenv.py
vendored
@ -200,8 +200,7 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
|
||||
def step(self, action):
|
||||
assert len(action) == self.env_num
|
||||
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
|
||||
result = [ray.get(r) for r in result_obj]
|
||||
result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)])
|
||||
self._obs, self._rew, self._done, self._info = zip(*result)
|
||||
self._obs = np.stack(self._obs)
|
||||
self._rew = np.stack(self._rew)
|
||||
@ -212,7 +211,7 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
def reset(self, id=None):
|
||||
if id is None:
|
||||
result_obj = [e.reset.remote() for e in self.envs]
|
||||
self._obs = np.stack([ray.get(r) for r in result_obj])
|
||||
self._obs = np.stack(ray.get(result_obj))
|
||||
else:
|
||||
result_obj = []
|
||||
if np.isscalar(id):
|
||||
@ -230,16 +229,12 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
seed = [seed + _ for _ in range(self.env_num)]
|
||||
elif seed is None:
|
||||
seed = [seed] * self.env_num
|
||||
result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
|
||||
return [ray.get(r) for r in result_obj]
|
||||
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
|
||||
|
||||
def render(self, **kwargs):
|
||||
if not hasattr(self.envs[0], 'render'):
|
||||
return
|
||||
result_obj = [e.render.remote(**kwargs) for e in self.envs]
|
||||
return [ray.get(r) for r in result_obj]
|
||||
return ray.get([e.render.remote(**kwargs) for e in self.envs])
|
||||
|
||||
def close(self):
|
||||
result_obj = [e.close.remote() for e in self.envs]
|
||||
for r in result_obj:
|
||||
ray.get(r)
|
||||
return ray.get([e.close.remote() for e in self.envs])
|
||||
|
Loading…
x
Reference in New Issue
Block a user