<div align="center">
<a href="http://tianshou.readthedocs.io"><img width="300px" height="auto" src="docs/_static/images/tianshou-logo.png"></a>
</div>

---
[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/)
[![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/latest)
[![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/latest/)
[![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions)
[![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou)
[![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues)
[![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network)
[![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
[![Join the chat at https://gitter.im/thu-ml/tianshou](https://badges.gitter.im/thu-ml/tianshou.svg)](https://gitter.im/thu-ml/tianshou?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or run slowly, Tianshou provides a fast, modularized framework and a pythonic API for building deep reinforcement learning agents with the least number of lines of code. The currently supported algorithms include:
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- Vanilla Imitation Learning
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
Here are Tianshou's other features:
- Elegant framework, using only ~2000 lines of code
- Support parallel environment sampling for all algorithms (see the sketch after this list) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation for all Q-learning based algorithms
- Support multi-agent RL easily [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
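As a quick illustration of the parallel sampling item above, here is a minimal sketch; the environment name (`CartPole-v0`) and the number of workers (8) are arbitrary choices for this example:
```python
import gym
import numpy as np
import tianshou as ts

# run 8 CartPole instances, each in its own worker process
envs = ts.env.SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
obs = envs.reset()                                         # stacked observations, one row per worker
obs, rew, done, info = envs.step(np.zeros(8, dtype=int))   # step all workers with one call
envs.close()
```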
In Chinese, Tianshou means divinely ordained, and by extension the gift one is born with. Tianshou is a reinforcement learning platform, and RL algorithms do not learn from humans. Taking the name "Tianshou" therefore means that there is no teacher to learn from; instead, the agent learns by itself through constant interaction with the environment.
“天授”意指上天所授,引申为与生具有的天赋。天授是强化学习平台,而强化学习算法并不是向人类学习的,所以取“天授”意思是没有老师来教,而是自己通过跟环境不断交互来进行学习。
## Installation
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
```bash
pip3 install tianshou
```
You can also install the latest version from GitHub:
```bash
pip3 install git+https://github.com/thu-ml/tianshou.git@master
```
If you use Anaconda or Miniconda, you can install Tianshou with the following commands:
```bash
# create a new virtualenv and install pip, change the env name if you like
conda create -n myenv pip
# activate the environment
conda activate myenv
# install tianshou
pip install tianshou
```
After installation, open your Python console and type
```python
import tianshou as ts
print(ts.__version__)
```
If no error occurs, you have successfully installed Tianshou.
## Documentation
The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
The example scripts are under the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders. They may fail to run with the PyPI installation; in that case, please re-install the GitHub version with `pip3 install git+https://github.com/thu-ml/tianshou.git@master`.
The Chinese documentation is available at [https://tianshou.readthedocs.io/zh/latest/](https://tianshou.readthedocs.io/zh/latest/).
<!-- A short Chinese introduction to Tianshou: https://www.zhihu.com/question/377263715 -->
## Why Tianshou?
### Fast-speed
Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It takes only 3 seconds to train an agent based on vanilla policy gradient on the CartPole-v0 task (the seed may differ across platforms and devices):
```bash
python3 test/discrete/test_pg.py --seed 0 --render 0.03
```
<div align="center">
<img src="docs/_static/images/testpg.gif">
</div>
We selected some well-known reinforcement learning platforms: the 2 most-starred GitHub repositories among all RL platforms (OpenAI Baselines and RLlib) and the 2 most-starred GitHub repositories among PyTorch RL platforms (PyTorch-DRL and rlpyt). Here are the benchmark results for other algorithms and platforms on toy scenarios (tested on the same laptop as above):
| RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch-DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| GitHub Stars | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) |
| Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? |
| DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
| PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
| PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? |
| DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s |
| TD3 - Pendulum | 44.04±6.37s | None | 99.75±21.63s | 149.90±7.54s | 57.52±17.71s \*\* | 210.31±76.30s |
| SAC - Pendulum | 36.02±0.77s | None | 124.85±79.14s | 97.42±4.75s | 63.80±27.37s \*\* | 295.92±140.85s |
*\*: Could not reach the target reward threshold in 1e6 steps in any of 5 runs. The total runtime is in the brackets.*
*\*\*: Since no specific evaluation function is implemented in PyTorch-DRL, the condition is relaxed to "The average total reward for 20 consecutive complete games during training is greater than or equal to threshold".*
*?: We have tried, but it is nontrivial to run non-Atari games on rlpyt. See [here](https://github.com/astooke/rlpyt/issues/135).*
All of the platforms use 5 different seeds for testing. We exclude the trials that failed to train. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum, measured over the mean returns of 100 consecutive episodes (except for PyTorch-DRL).
We will add results for Atari Pong and MuJoCo soon.
### Reproducible
Tianshou has its own unit tests. Unlike other platforms, **the unit tests include the full agent training procedure for all of the implemented algorithms**. A test fails if an algorithm cannot train an agent to perform well enough within a limited number of epochs on toy scenarios. The unit tests secure the reproducibility of our platform.
Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page for more detail.
### Modularized Policy
We decouple all of the algorithms into 4 parts:
- `__init__`: initialize the policy;
- `forward`: compute actions from given observations;
- `process_fn`: preprocess data from the replay buffer (since we have reformulated all algorithms as replay-buffer based algorithms);
- `learn`: learn from a given batch of data.
With this API, we can interact with different policies conveniently.
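For illustration, here is a minimal sketch of a custom policy that follows this four-part interface; the class name and its random behavior are made up for this example and are not part of the library:
```python
import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomDiscretePolicy(BasePolicy):
    """A toy policy that picks discrete actions uniformly at random."""

    def __init__(self, action_num, **kwargs):
        super().__init__(**kwargs)       # initialize the policy
        self.action_num = action_num

    def forward(self, batch, state=None, **kwargs):
        # compute actions over the given observations
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(act=act)

    def process_fn(self, batch, buffer, indice):
        # preprocess data sampled from the replay buffer (e.g. compute returns)
        return batch

    def learn(self, batch, **kwargs):
        # update parameters with a batch of data and return training statistics
        return {}
```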
### Elegant and Flexible
Currently, the overall code of the Tianshou platform is less than 1500 lines, excluding the environment wrappers for Atari and MuJoCo. Most of the implemented algorithms are less than 100 lines of Python code. It is quite easy to go through the framework and understand how it works. We provide many flexible APIs. For instance, if you want to use your policy to interact with the environment for (at least) `n` steps:
```python
result = collector.collect(n_step=n)
```
If you have 3 environments in total and want to collect 1 episode in the first environment and 3 in the third environment:
```python
result = collector.collect(n_episode=[1, 0, 3])
```
If you want to train the given policy with a sampled batch:
```python
result = policy.learn(collector.sample(batch_size))
```
You can check out the [documentation](https://tianshou.readthedocs.io) for further usage.
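Putting these calls together, a hand-written training loop can be as short as the following sketch (assuming a `policy`, a `train_collector`, and a `batch_size` set up as in the Quick Start below; the off-policy trainer shown there wraps a loop of this shape with testing, logging, and stop conditions):
```python
for i in range(1000):
    # gather (at least) 10 transitions into the replay buffer
    collect_result = train_collector.collect(n_step=10)
    # sample a batch from the buffer and perform one policy update
    losses = policy.learn(train_collector.sample(batch_size))
```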
## Quick Start
This is an example of a Deep Q-Network. You can also run the full script at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
First, import some relevant packages:
```python
import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
```
Define some hyper-parameters:
```python
task = 'CartPole-v0'
lr = 1e-3
gamma = 0.9
n_step = 3
eps_train, eps_test = 0.1, 0.05
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
target_freq = 320
batch_size = 64
train_num, test_num = 8, 100
buffer_size = 20000
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
```
Make environments:
```python
# you can also try with SubprocVectorEnv
train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```
Define the network:
```python
from tianshou.utils.net.common import Net
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)
```
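If you prefer to write the network yourself instead of using the built-in `Net`, any `nn.Module` whose `forward` takes the observation (plus an optional recurrent `state` and `info`) and returns `(logits, state)` should work in its place. A minimal sketch, reusing the imports and variables defined above:
```python
class MLP(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape)),
        )

    def forward(self, obs, state=None, info={}):
        # convert the (possibly numpy) observation into a float tensor
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        logits = self.model(obs.view(len(obs), -1))
        return logits, state

# net = MLP(state_shape, action_shape)  # would replace the Net instance above
```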
Set up the policy and collectors:
```python
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
    target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)
```
Let's train it:
```python
result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
    test_fn=lambda e: policy.set_eps(eps_test),
    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
print(f'Finished training! Use {result["duration"]}')
```
Save / load the trained policy (it's exactly the same as a PyTorch `nn.Module`):
```python
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
```
Watch the agent's performance at 35 FPS:
```python
collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
collector.close()
```
Look at the results saved in TensorBoard (run this in your terminal):
```bash
tensorboard --logdir log/dqn
```
You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
## Contributing
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [docs/contributing.rst](https://github.com/thu-ml/tianshou/blob/master/docs/contributing.rst).
## TODO
Check out the [Issue/PR Categories](https://github.com/thu-ml/tianshou/projects/2) and [Support Status](https://github.com/thu-ml/tianshou/projects/3) page for more detail.
## Citing Tianshou
If you find Tianshou useful, please cite it in your publications.
```latex
@misc{tianshou,
  author = {Jiayi Weng and Minghao Zhang and Dong Yan and Hang Su and Jun Zhu},
  title = {Tianshou},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/thu-ml/tianshou}},
}
```
## Acknowledgment
Tianshou was previously a reinforcement learning platform based on TensorFlow. You can check out the [`priv`](https://github.com/thu-ml/tianshou/tree/priv) branch for more detail. Many thanks to [Haosheng Zou](https://github.com/HaoshengZou) for his pioneering work on Tianshou before version 0.1.1.
We would like to thank [TSAIL](http://ml.cs.tsinghua.edu.cn/) and [Institute for Artificial Intelligence, Tsinghua University](http://ml.cs.tsinghua.edu.cn/thuai/) for providing such an excellent AI research platform.