Commit d9e4b9d16f by Trinkle23897, 2020-03-29 10:22:03 +08:00 (parent a326d30739)
8 changed files with 100 additions and 24 deletions

.gitignore

@@ -140,3 +140,4 @@ dmypy.json
flake8.sh
log/
MUJOCO_LOG.TXT
+ *.pth

README.md

@@ -2,6 +2,7 @@
<a href="http://tianshou.readthedocs.io"><img width="300px" height="auto" src="docs/_static/images/tianshou-logo.png"></a>
</div>
[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/)
[![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions)
[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io)
@@ -15,7 +16,7 @@
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) + n-step returns
+ - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
- [Advantage Actor-Critic (A2C)](http://incompleteideas.net/book/RLbook2018.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -24,7 +25,7 @@
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformulated as replay-buffer based algorithms.
- Tianshou is still under development. More algorithms are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [CONTRIBUTING.md](/CONTRIBUTING.md).
+ Tianshou is still under development. More algorithms are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [CONTRIBUTING.md](https://github.com/thu-ml/tianshou/blob/master/CONTRIBUTING.md).
## Installation
@@ -34,11 +35,26 @@ Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). You
pip3 install tianshou
```
+ You can also install the latest version from GitHub:
+ ```bash
+ pip3 install git+https://github.com/thu-ml/tianshou.git@master
+ ```
+ After installation, open your Python console and type
+ ```python
+ import tianshou as ts
+ print(ts.__version__)
+ ```
+ If no error occurs, you have successfully installed Tianshou.
## Documentation
The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io).
- The example scripts are under [test/](/test/) folder and [examples/](/examples/) folder.
+ The example scripts are under the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
## Why Tianshou?
@@ -50,7 +66,7 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
<img src="docs/_static/images/testpg.gif"></a>
</div>
- We select some of famous (>1k stars) reinforcement learning platforms. Here is the benchmark result for other algorithms and platforms on toy scenarios:
+ We selected some famous (>1k-star) reinforcement learning platforms. Here are the benchmark results for other algorithms and platforms on toy scenarios (tested on the same laptop as mentioned above):
| RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
@@ -115,18 +131,18 @@ You can check out the [documentation](https://tianshou.readthedocs.io) for furth
## Quick Start
- This is an example of Deep Q Network. You can also run the full script under [test/discrete/test_dqn.py](/test/discrete/test_dqn.py).
+ This is an example of a Deep Q-Network. You can also run the full script at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
- First, import the relevant packages:
+ First, import some relevant packages:
```python
import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
- from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
+ from tianshou.env import VectorEnv, SubprocVectorEnv
```
Define some hyper-parameters:
@@ -147,14 +163,15 @@ buffer_size = 20000
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
```
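The hunk above shows only the tail of this step (`buffer_size` and the new `writer` line). For readers following along, the elided hyper-parameters look roughly like the sketch below; the names reappear in later snippets, but the concrete values are illustrative assumptions, not the README verbatim:
```python
task = 'CartPole-v0'             # toy task used throughout this example
lr = 1e-3                        # learning rate for the optimizer
gamma = 0.9                      # discount factor
n_step = 3                       # number of steps for n-step return estimation
eps_train, eps_test = 0.1, 0.05  # epsilon-greedy rates for training / evaluation
epoch, step_per_epoch, collect_per_step = 10, 1000, 10
batch_size = 64
train_num, test_num = 8, 100     # number of training / test environments
buffer_size = 20000
```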
- Make envs:
+ Make environments:
```python
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
- train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
- test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
+ # you can also try with SubprocVectorEnv
+ train_envs = VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
+ test_envs = VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```
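A note on the swap this commit makes throughout: VectorEnv steps its environments sequentially in the calling process, while SubprocVectorEnv runs each environment in its own subprocess. For cheap toy environments such as CartPole, the in-process version avoids inter-process communication overhead, which is presumably why the examples switch to it. Reverting is a one-line change:
```python
# worth trying when a single env step is expensive enough to
# amortize the cost of inter-process communication:
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
```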
Define the network:
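The network definition itself falls between the hunks and is not shown. Below is a minimal sketch of the kind of module DQNPolicy consumes here, assuming the imports above; the layer sizes and the optimizer line are illustrative, not the repository's exact code:
```python
class Net(nn.Module):
    """A minimal MLP mapping flattened observations to one Q-value per action."""
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, int(np.prod(action_shape))),
        )

    def forward(self, obs, state=None, info={}):
        # the collector feeds raw numpy observations; convert them once here
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        logits = self.model(obs.reshape(obs.shape[0], -1))
        return logits, state

net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)
```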
@@ -197,6 +214,7 @@ result = offpolicy_trainer(
test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
test_fn=lambda e: policy.set_eps(eps_test),
stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
+ print(f'Finished training! Use {result["duration"]}')
```
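The policy, the collectors, and the head of this call sit outside the hunk. Pieced together with the names above, the training step reads roughly as follows; treat the DQNPolicy construction and the exact positional-argument order as assumptions rather than the file verbatim:
```python
policy = DQNPolicy(net, optim, gamma, n_step)  # DQN with n-step return estimation
train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
test_collector = Collector(policy, test_envs)

result = offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch,
    collect_per_step, test_num, batch_size,
    train_fn=lambda e: policy.set_eps(eps_train),
    test_fn=lambda e: policy.set_eps(eps_test),
    stop_fn=lambda x: x >= env.spec.reward_threshold,
    writer=writer, task=task)
```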
Saving / loading the trained policy (it's exactly the same as saving a PyTorch nn.Module):
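The snippet itself is outside the hunk, but since the policy is an ordinary nn.Module, the usual PyTorch pattern applies (the file name is illustrative; note the matching new `*.pth` entry in .gitignore above):
```python
torch.save(policy.state_dict(), 'dqn.pth')     # save the learned weights
policy.load_state_dict(torch.load('dqn.pth'))  # restore them later
```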
@@ -211,6 +229,7 @@ Watch the performance with 35 FPS:
```python
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1/35)
+ collector.close()
```
Looking at the results saved in TensorBoard (from a bash shell, e.g. `tensorboard --logdir log/dqn`):

docs/conf.py

@@ -44,7 +44,6 @@ extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
- 'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.imgmath',
'sphinx.ext.mathjax',
@@ -77,7 +76,7 @@ html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
- html_logo = '_static/images/tianshou-logo.svg'
+ html_logo = '_static/images/tianshou-logo.png'
def setup(app):

docs/contributing.rst (new file)

@@ -0,0 +1,7 @@
Contributing
============
We always welcome contributions to help make Tianshou better. If you would like to contribute, please check out the `guidelines <https://github.com/thu-ml/tianshou/blob/master/CONTRIBUTING.md>`_. Below is an incomplete list of our contributors (find more on `this page <https://github.com/thu-ml/tianshou/graphs/contributors>`_):

* Jiayi Weng (`Trinkle23897 <https://github.com/Trinkle23897>`_)
* Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)

docs/index.rst

@@ -3,17 +3,67 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
- Welcome to Tianshou's documentation!
- ====================================
+ Welcome to Tianshou!
+ ====================
+ **Tianshou** (天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or run slowly, Tianshou provides a fast framework and a pythonic API for building deep reinforcement learning agents. The supported algorithms include:
+
+ * `Policy Gradient (PG) <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
+ * `Deep Q-Network (DQN) <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
+ * `Double DQN (DDQN) <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
+ * `Advantage Actor-Critic (A2C) <http://incompleteideas.net/book/RLbook2018.pdf>`_
+ * `Deep Deterministic Policy Gradient (DDPG) <https://arxiv.org/pdf/1509.02971.pdf>`_
+ * `Proximal Policy Optimization (PPO) <https://arxiv.org/pdf/1707.06347.pdf>`_
+ * `Twin Delayed DDPG (TD3) <https://arxiv.org/pdf/1802.09477.pdf>`_
+ * `Soft Actor-Critic (SAC) <https://arxiv.org/pdf/1812.05905.pdf>`_
+
+ Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformulated as replay-buffer based algorithms.
+
+ Installation
+ ------------
+
+ Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
+
+ ::
+
+     pip3 install tianshou
+
+ You can also install the latest version from GitHub:
+
+ ::
+
+     pip3 install git+https://github.com/thu-ml/tianshou.git@master
+
+ After installation, open your Python console and type
+
+ ::
+
+     import tianshou as ts
+     print(ts.__version__)
+
+ If no error occurs, you have successfully installed Tianshou.
.. toctree::
-    :maxdepth: 2
-    :caption: Contents:
+    :maxdepth: 1
+    :caption: Tutorials
+
+ .. toctree::
+    :maxdepth: 1
+    :caption: API Docs
+
+ .. toctree::
+    :maxdepth: 1
+    :caption: Community
+
+    contributing
Indices and tables
- ==================
+ ------------------
* :ref:`genindex`
* :ref:`modindex`

test/discrete/test_dqn.py

@@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
- from tianshou.env import SubprocVectorEnv
+ from tianshou.env import VectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@@ -48,10 +48,10 @@ def test_dqn(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
- train_envs = SubprocVectorEnv(
+ train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
- test_envs = SubprocVectorEnv(
+ test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)

test/discrete/test_pg.py

@@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PGPolicy
- from tianshou.env import SubprocVectorEnv
+ from tianshou.env import VectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Batch, Collector, ReplayBuffer
@@ -99,10 +99,10 @@ def test_pg(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
- train_envs = SubprocVectorEnv(
+ train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
- test_envs = SubprocVectorEnv(
+ test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)

tianshou/__init__.py

@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, \
exploration
- __version__ = '0.2.0post1'
+ __version__ = '0.2.0post2'
__all__ = [
'env',
'data',