docs fix and v0.2.5 (#156)
* pre * update docs * update docs * $ in bash * size -> hidden_layer_size * doctest * doctest again * filter a warning * fix bug * fix examples * test fail * test succ
This commit is contained in:
parent
089b85b6a2
commit
bd9c3c7f8d
9
.github/ISSUE_TEMPLATE.md
vendored
9
.github/ISSUE_TEMPLATE.md
vendored
@ -3,15 +3,10 @@
|
||||
+ [ ] RL algorithm bug
|
||||
+ [ ] documentation request (i.e. "X is missing from the documentation.")
|
||||
+ [ ] new feature request
|
||||
- [ ] I have visited the [source website], and in particular read the [known issues]
|
||||
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
|
||||
- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/)
|
||||
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
||||
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
||||
```python
|
||||
import tianshou, torch, sys
|
||||
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
|
||||
```
|
||||
|
||||
[source website]: https://github.com/thu-ml/tianshou/
|
||||
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
|
||||
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
|
||||
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=
|
||||
|
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -7,15 +7,10 @@
|
||||
|
||||
Less important but also useful:
|
||||
|
||||
- [ ] I have visited the [source website], and in particular read the [known issues]
|
||||
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
|
||||
- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou)
|
||||
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
||||
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
||||
```python
|
||||
import tianshou, torch, sys
|
||||
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
|
||||
```
|
||||
|
||||
[source website]: https://github.com/thu-ml/tianshou
|
||||
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
|
||||
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
|
||||
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=
|
||||
|
@ -1,4 +1,4 @@
|
||||
name: PEP8 Check
|
||||
name: PEP8 and Docs Check
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
@ -11,9 +11,20 @@ jobs:
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.8
|
||||
- name: Upgrade pip
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install flake8
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
flake8 . --count --show-source --statistics
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install ".[dev]" --upgrade
|
||||
- name: Documentation test
|
||||
run: |
|
||||
cd docs
|
||||
make html SPHINXOPTS="-W"
|
||||
cd ..
|
60
README.md
60
README.md
@ -38,7 +38,7 @@ Here is Tianshou's other features:
|
||||
- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
|
||||
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
|
||||
- Support n-step returns estimation for all Q-learning based algorithms
|
||||
- Support multi-agent RL easily [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
|
||||
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
|
||||
|
||||
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
|
||||
|
||||
@ -49,24 +49,27 @@ In Chinese, Tianshou means divinely ordained and is derived to the gift of being
|
||||
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
|
||||
|
||||
```bash
|
||||
pip3 install tianshou
|
||||
$ pip install tianshou
|
||||
```
|
||||
|
||||
You can also install with the newest version through GitHub:
|
||||
|
||||
```bash
|
||||
pip3 install git+https://github.com/thu-ml/tianshou.git@master
|
||||
# latest release
|
||||
$ pip install git+https://github.com/thu-ml/tianshou.git@master
|
||||
# develop version
|
||||
$ pip install git+https://github.com/thu-ml/tianshou.git@dev
|
||||
```
|
||||
|
||||
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
|
||||
|
||||
```bash
|
||||
# create a new virtualenv and install pip, change the env name if you like
|
||||
conda create -n myenv pip
|
||||
$ conda create -n myenv pip
|
||||
# activate the environment
|
||||
conda activate myenv
|
||||
$ conda activate myenv
|
||||
# install tianshou
|
||||
pip install tianshou
|
||||
$ pip install tianshou
|
||||
```
|
||||
|
||||
After installation, open your python console and type
|
||||
@ -82,9 +85,9 @@ If no error occurs, you have successfully installed Tianshou.
|
||||
|
||||
The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
|
||||
|
||||
The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder. It may fail to run with PyPI installation, so please re-install the github version through `pip3 install git+https://github.com/thu-ml/tianshou.git@master`.
|
||||
The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder.
|
||||
|
||||
中文文档位于 [https://tianshou.readthedocs.io/zh/latest/](https://tianshou.readthedocs.io/zh/latest/)
|
||||
中文文档位于 [https://tianshou.readthedocs.io/zh/latest/](https://tianshou.readthedocs.io/zh/latest/)。
|
||||
|
||||
<!-- 这里有一份天授平台简短的中文简介:https://www.zhihu.com/question/377263715 -->
|
||||
|
||||
@ -95,7 +98,7 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
|
||||
Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It only uses 3 seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device)
|
||||
|
||||
```bash
|
||||
python3 test/discrete/test_pg.py --seed 0 --render 0.03
|
||||
$ python3 test/discrete/test_pg.py --seed 0 --render 0.03
|
||||
```
|
||||
|
||||
<div align="center">
|
||||
@ -108,10 +111,10 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
|
||||
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||
| GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/hill-a/stable-baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
|
||||
| Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
|
||||
| PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? |
|
||||
| DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
|
||||
| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
|
||||
| PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
|
||||
| PG - CartPole | 9.02±6.79s | None | None | 19.26±2.29s | None | ? |
|
||||
| DQN - CartPole | 6.72±1.28s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
|
||||
| A2C - CartPole | 15.33±4.48s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
|
||||
| PPO - CartPole | 6.01±1.14s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
|
||||
| PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? |
|
||||
| DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s |
|
||||
| TD3 - Pendulum | 44.04±6.37s | None | 99.75±21.63s | 149.90±7.54s | 57.52±17.71s \*\* | 210.31±76.30s |
|
||||
@ -142,7 +145,7 @@ We decouple all of the algorithms into 4 parts:
|
||||
- `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
|
||||
- `learn`: to learn from a given batch data.
|
||||
|
||||
Within these API, we can interact with different policies conveniently.
|
||||
Within this API, we can interact with different policies conveniently.
|
||||
|
||||
### Elegant and Flexible
|
||||
|
||||
@ -182,17 +185,12 @@ Define some hyper-parameters:
|
||||
|
||||
```python
|
||||
task = 'CartPole-v0'
|
||||
lr = 1e-3
|
||||
gamma = 0.9
|
||||
n_step = 3
|
||||
eps_train, eps_test = 0.1, 0.05
|
||||
epoch = 10
|
||||
step_per_epoch = 1000
|
||||
collect_per_step = 10
|
||||
target_freq = 320
|
||||
batch_size = 64
|
||||
lr, epoch, batch_size = 1e-3, 10, 64
|
||||
train_num, test_num = 8, 100
|
||||
gamma, n_step, target_freq = 0.9, 3, 320
|
||||
buffer_size = 20000
|
||||
eps_train, eps_test = 0.1, 0.05
|
||||
step_per_epoch, collect_per_step = 1000, 10
|
||||
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
|
||||
```
|
||||
|
||||
@ -208,7 +206,8 @@ Define the network:
|
||||
|
||||
```python
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
# you can define other net by following the API:
|
||||
# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network
|
||||
env = gym.make(task)
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
action_shape = env.action_space.shape or env.action_space.n
|
||||
@ -219,8 +218,7 @@ optim = torch.optim.Adam(net.parameters(), lr=lr)
|
||||
Setup policy and collectors:
|
||||
|
||||
```python
|
||||
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
|
||||
target_update_freq=target_freq)
|
||||
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
|
||||
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
|
||||
test_collector = ts.data.Collector(policy, test_envs)
|
||||
```
|
||||
@ -236,7 +234,7 @@ result = ts.trainer.offpolicy_trainer(
|
||||
print(f'Finished training! Use {result["duration"]}')
|
||||
```
|
||||
|
||||
Save / load the trained policy (it's exactly the same as PyTorch nn.module):
|
||||
Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):
|
||||
|
||||
```python
|
||||
torch.save(policy.state_dict(), 'dqn.pth')
|
||||
@ -254,18 +252,18 @@ collector.close()
|
||||
Look at the result saved in tensorboard: (with bash script in your terminal)
|
||||
|
||||
```bash
|
||||
tensorboard --logdir log/dqn
|
||||
$ tensorboard --logdir log/dqn
|
||||
```
|
||||
|
||||
You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
|
||||
|
||||
## Contributing
|
||||
|
||||
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [docs/contributing.rst](https://github.com/thu-ml/tianshou/blob/master/docs/contributing.rst).
|
||||
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html).
|
||||
|
||||
## TODO
|
||||
|
||||
Check out the [Issue/PR Categories](https://github.com/thu-ml/tianshou/projects/2) and [Support Status](https://github.com/thu-ml/tianshou/projects/3) page for more detail.
|
||||
Check out the [Project](https://github.com/thu-ml/tianshou/projects) page for more detail.
|
||||
|
||||
## Citing Tianshou
|
||||
|
||||
@ -273,7 +271,7 @@ If you find Tianshou useful, please cite it in your publications.
|
||||
|
||||
```latex
|
||||
@misc{tianshou,
|
||||
author = {Jiayi Weng, Minghao Zhang, Dong Yan, Hang Su, Jun Zhu},
|
||||
author = {Jiayi Weng, Minghao Zhang, Alexis Duburcq, Kaichao You, Dong Yan, Hang Su, Jun Zhu},
|
||||
title = {Tianshou},
|
||||
year = {2020},
|
||||
publisher = {GitHub},
|
||||
|
@ -41,7 +41,7 @@ extensions = [
|
||||
'sphinx.ext.doctest',
|
||||
'sphinx.ext.intersphinx',
|
||||
'sphinx.ext.coverage',
|
||||
'sphinx.ext.imgmath',
|
||||
# 'sphinx.ext.imgmath',
|
||||
'sphinx.ext.mathjax',
|
||||
'sphinx.ext.ifconfig',
|
||||
'sphinx.ext.viewcode',
|
||||
@ -58,7 +58,9 @@ master_doc = 'index'
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
autodoc_default_options = {'special-members': '__call__, __getitem__, __len__'}
|
||||
autodoc_default_options = {'special-members': ', '.join([
|
||||
'__len__', '__call__', '__getitem__', '__setitem__',
|
||||
'__getattr__', '__setattr__'])}
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
|
@ -8,13 +8,14 @@ To install Tianshou in an "editable" mode, run
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip3 install -e ".[dev]"
|
||||
$ git checkout dev
|
||||
$ pip install -e ".[dev]"
|
||||
|
||||
in the main directory. This installation is removable by
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python3 setup.py develop --uninstall
|
||||
$ python setup.py develop --uninstall
|
||||
|
||||
PEP8 Code Style Check
|
||||
---------------------
|
||||
@ -23,7 +24,7 @@ We follow PEP8 python code style. To check, in the main directory, run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
flake8 . --count --show-source --statistics
|
||||
$ flake8 . --count --show-source --statistics
|
||||
|
||||
Test Locally
|
||||
------------
|
||||
@ -32,7 +33,7 @@ This command will run automatic tests in the main directory
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pytest test --cov tianshou -s --durations 0 -v
|
||||
$ pytest test --cov tianshou -s --durations 0 -v
|
||||
|
||||
Test by GitHub Actions
|
||||
----------------------
|
||||
@ -65,11 +66,13 @@ To compile documentation into webpages, run
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
make html
|
||||
$ make html
|
||||
|
||||
under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
|
||||
|
||||
Chinese Documentation
|
||||
---------------------
|
||||
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
|
||||
|
||||
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/
|
||||
Pull Request
|
||||
------------
|
||||
|
||||
All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
|
||||
|
@ -6,4 +6,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
|
||||
* Jiayi Weng (`Trinkle23897 <https://github.com/Trinkle23897>`_)
|
||||
* Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
|
||||
* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
|
||||
* Kaichao You (`youkaichao <https://github.com/youkaichao>`_)
|
||||
* Kaichao You (`youkaichao <https://github.com/youkaichao>`_)
|
||||
|
@ -10,7 +10,7 @@ Welcome to Tianshou!
|
||||
|
||||
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
|
||||
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
||||
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
||||
@ -28,32 +28,38 @@ Here is Tianshou's other features:
|
||||
* Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
|
||||
* Support customized training process: :ref:`customize_training`
|
||||
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
|
||||
* Support multi-agent RL easily (a tutorial is available at :doc:`/tutorials/tictactoe`)
|
||||
* Support multi-agent RL: :doc:`/tutorials/tictactoe`
|
||||
|
||||
中文文档位于 https://tianshou.readthedocs.io/zh/latest/
|
||||
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command (with Python >= 3.6):
|
||||
::
|
||||
|
||||
pip3 install tianshou
|
||||
.. code-block:: bash
|
||||
|
||||
$ pip install tianshou
|
||||
|
||||
You can also install with the newest version through GitHub:
|
||||
::
|
||||
|
||||
pip3 install git+https://github.com/thu-ml/tianshou.git@master
|
||||
.. code-block:: bash
|
||||
|
||||
# latest release
|
||||
$ pip install git+https://github.com/thu-ml/tianshou.git@master
|
||||
# develop version
|
||||
$ pip install git+https://github.com/thu-ml/tianshou.git@dev
|
||||
|
||||
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
|
||||
::
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# create a new virtualenv and install pip, change the env name if you like
|
||||
conda create -n myenv pip
|
||||
$ conda create -n myenv pip
|
||||
# activate the environment
|
||||
conda activate myenv
|
||||
$ conda activate myenv
|
||||
# install tianshou
|
||||
pip install tianshou
|
||||
$ pip install tianshou
|
||||
|
||||
After installation, open your python console and type
|
||||
::
|
||||
@ -63,7 +69,7 @@ After installation, open your python console and type
|
||||
|
||||
If no error occurs, you have successfully installed Tianshou.
|
||||
|
||||
Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
|
||||
Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
@ -8,7 +8,6 @@ The full script is at `test/discrete/test_dqn.py <https://github.com/thu-ml/tian
|
||||
|
||||
Contrary to existing Deep RL libraries such as `RLlib <https://github.com/ray-project/ray/tree/master/rllib/>`_, which could only accept a config specification of hyperparameters, network, and others, Tianshou provides an easy way of construction through the code-level.
|
||||
|
||||
|
||||
Make an Environment
|
||||
-------------------
|
||||
|
||||
@ -22,7 +21,6 @@ First of all, you have to make an environment for your agent to interact with. F
|
||||
|
||||
CartPole-v0 is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both, depending on the probability distribution on the action.
|
||||
|
||||
|
||||
Setup Multi-environment Wrapper
|
||||
-------------------------------
|
||||
|
||||
@ -79,17 +77,13 @@ You can also have a try with those pre-defined networks in :mod:`~tianshou.utils
|
||||
1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
|
||||
2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
|
||||
|
||||
|
||||
Setup Policy
|
||||
------------
|
||||
|
||||
We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with using a target network:
|
||||
::
|
||||
|
||||
policy = ts.policy.DQNPolicy(net, optim,
|
||||
discount_factor=0.9, estimation_step=3,
|
||||
target_update_freq=320)
|
||||
|
||||
policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)
|
||||
|
||||
Setup Collector
|
||||
---------------
|
||||
@ -101,7 +95,6 @@ In each step, the collector will let the policy perform (at least) a specified n
|
||||
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
|
||||
test_collector = ts.data.Collector(policy, test_envs)
|
||||
|
||||
|
||||
Train Policy with a Trainer
|
||||
---------------------------
|
||||
|
||||
@ -118,7 +111,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
|
||||
writer=None)
|
||||
print(f'Finished training! Use {result["duration"]}')
|
||||
|
||||
The meaning of each parameter is as follows:
|
||||
The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`):
|
||||
|
||||
* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
|
||||
* ``step_per_epoch``: The number of step for updating policy network in one epoch;
|
||||
@ -157,7 +150,6 @@ The returned result is a dictionary as follows:
|
||||
|
||||
It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03.
|
||||
|
||||
|
||||
Save/Load Policy
|
||||
----------------
|
||||
|
||||
@ -167,7 +159,6 @@ Since the policy inherits the ``torch.nn.Module`` class, saving and loading the
|
||||
torch.save(policy.state_dict(), 'dqn.pth')
|
||||
policy.load_state_dict(torch.load('dqn.pth'))
|
||||
|
||||
|
||||
Watch the Agent's Performance
|
||||
-----------------------------
|
||||
|
||||
@ -178,7 +169,6 @@ Watch the Agent's Performance
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.close()
|
||||
|
||||
|
||||
.. _customized_trainer:
|
||||
|
||||
Train a Policy with Customized Codes
|
||||
@ -186,7 +176,7 @@ Train a Policy with Customized Codes
|
||||
|
||||
"I don't want to use your provided trainer. I want to customize it!"
|
||||
|
||||
No problem! Tianshou supports user-defined training code. Here is the usage:
|
||||
Tianshou supports user-defined training code. Here is the code snippet:
|
||||
::
|
||||
|
||||
# pre-collect 5000 frames with random action before training
|
||||
@ -212,7 +202,7 @@ No problem! Tianshou supports user-defined training code. Here is the usage:
|
||||
# train policy with a sampled batch data
|
||||
losses = policy.learn(train_collector.sample(batch_size=64))
|
||||
|
||||
For further usage, you can refer to :doc:`/tutorials/cheatsheet`.
|
||||
For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.
|
||||
|
||||
.. rubric:: References
|
||||
|
||||
|
@ -162,8 +162,8 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
|
||||
|
||||
Random agents perform badly. In the above game, although agent 2 wins finally, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly.
|
||||
|
||||
Train a MARL Agent
|
||||
------------------
|
||||
Train an MARL Agent
|
||||
-------------------
|
||||
|
||||
So let's start to train our Tic-Tac-Toe agent! First, import some required modules.
|
||||
::
|
||||
|
@ -84,7 +84,7 @@ def test_ddpg(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -94,7 +94,7 @@ def test_td3(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -19,7 +19,7 @@ def create_atari_environment(name=None, sticky_actions=True,
|
||||
|
||||
|
||||
def preprocess_fn(obs=None, act=None, rew=None, done=None,
|
||||
obs_next=None, info=None, policy=None):
|
||||
obs_next=None, info=None, policy=None, **kwargs):
|
||||
if obs_next is not None:
|
||||
obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
|
||||
obs_next = np.moveaxis(obs_next, 0, -1)
|
||||
|
@ -101,7 +101,7 @@ def test_td3(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -89,8 +89,7 @@ def test_a2c(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
|
||||
task=args.task)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
|
@ -94,7 +94,7 @@ def test_dqn(args=get_args()):
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
stop_fn=stop_fn, writer=writer)
|
||||
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -93,8 +93,7 @@ def test_ppo(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
|
||||
task=args.task)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,7 +1,7 @@
|
||||
from tianshou import data, env, utils, policy, trainer, \
|
||||
exploration
|
||||
|
||||
__version__ = '0.2.4'
|
||||
__version__ = '0.2.5'
|
||||
__all__ = [
|
||||
'env',
|
||||
'data',
|
||||
|
@ -10,29 +10,31 @@ class Net(nn.Module):
|
||||
please refer to :ref:`build_the_network`.
|
||||
|
||||
:param concat: whether the input shape is concatenated by state_shape
|
||||
and action_shape. If it is True, ``action_shape`` is not the output
|
||||
shape, but affects the input shape.
|
||||
and action_shape. If it is True, ``action_shape`` is not the output
|
||||
shape, but affects the input shape.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
|
||||
softmax=False, concat=False):
|
||||
softmax=False, concat=False, hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
input_size = np.prod(state_shape)
|
||||
if concat:
|
||||
input_size += np.prod(action_shape)
|
||||
self.model = [
|
||||
nn.Linear(input_size, 128),
|
||||
nn.Linear(input_size, hidden_layer_size),
|
||||
nn.ReLU(inplace=True)]
|
||||
for i in range(layer_num):
|
||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
||||
self.model += [nn.Linear(hidden_layer_size, hidden_layer_size),
|
||||
nn.ReLU(inplace=True)]
|
||||
if action_shape and not concat:
|
||||
self.model += [nn.Linear(128, np.prod(action_shape))]
|
||||
self.model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
|
||||
if softmax:
|
||||
self.model += [nn.Softmax(dim=-1)]
|
||||
self.model = nn.Sequential(*self.model)
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
"""s -> flatten -> logits"""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
s = s.flatten(1)
|
||||
logits = self.model(s)
|
||||
@ -44,17 +46,23 @@ class Recurrent(nn.Module):
|
||||
customize the network), please refer to :ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
|
||||
def __init__(self, layer_num, state_shape, action_shape,
|
||||
device='cpu', hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.state_shape = state_shape
|
||||
self.action_shape = action_shape
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(input_size=128, hidden_size=128,
|
||||
self.nn = nn.LSTM(input_size=hidden_layer_size,
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num, batch_first=True)
|
||||
self.fc1 = nn.Linear(np.prod(state_shape), 128)
|
||||
self.fc2 = nn.Linear(128, np.prod(action_shape))
|
||||
self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size)
|
||||
self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
"""In the evaluation mode, s should be with shape ``[bsz, dim]``; in
|
||||
the training mode, s should be with shape ``[bsz, len, dim]``. See the
|
||||
code and comment for more detail.
|
||||
"""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
||||
# In short, the tensor's shape in training phase is longer than which
|
||||
|
@ -11,13 +11,14 @@ class Actor(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net, action_shape,
|
||||
max_action, device='cpu'):
|
||||
max_action, device='cpu', hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(128, np.prod(action_shape))
|
||||
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
self._max = max_action
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
"""s -> logits -> action"""
|
||||
logits, h = self.preprocess(s, state)
|
||||
logits = self._max * torch.tanh(self.last(logits))
|
||||
return logits, h
|
||||
@ -28,13 +29,14 @@ class Critic(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net, device='cpu'):
|
||||
def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(128, 1)
|
||||
self.last = nn.Linear(hidden_layer_size, 1)
|
||||
|
||||
def forward(self, s, a=None, **kwargs):
|
||||
"""(s, a) -> logits -> Q(s, a)"""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
s = s.flatten(1)
|
||||
if a is not None:
|
||||
@ -51,17 +53,18 @@ class ActorProb(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net, action_shape,
|
||||
max_action, device='cpu', unbounded=False):
|
||||
def __init__(self, preprocess_net, action_shape, max_action,
|
||||
device='cpu', unbounded=False, hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.device = device
|
||||
self.mu = nn.Linear(128, np.prod(action_shape))
|
||||
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
|
||||
self._max = max_action
|
||||
self._unbounded = unbounded
|
||||
|
||||
def forward(self, s, state=None, **kwargs):
|
||||
"""s -> logits -> (mu, sigma)"""
|
||||
logits, h = self.preprocess(s, state)
|
||||
mu = self.mu(logits)
|
||||
if not self._unbounded:
|
||||
@ -78,15 +81,17 @@ class RecurrentActorProb(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num, state_shape, action_shape,
|
||||
max_action, device='cpu'):
|
||||
max_action, device='cpu', hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
|
||||
self.nn = nn.LSTM(input_size=np.prod(state_shape),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num, batch_first=True)
|
||||
self.mu = nn.Linear(128, np.prod(action_shape))
|
||||
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
|
||||
|
||||
def forward(self, s, **kwargs):
|
||||
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
||||
# In short, the tensor's shape in training phase is longer than which
|
||||
@ -107,16 +112,19 @@ class RecurrentCritic(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
|
||||
def __init__(self, layer_num, state_shape,
|
||||
action_shape=0, device='cpu', hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.state_shape = state_shape
|
||||
self.action_shape = action_shape
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
|
||||
self.nn = nn.LSTM(input_size=np.prod(state_shape),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num, batch_first=True)
|
||||
self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
|
||||
self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1)
|
||||
|
||||
def forward(self, s, a=None):
|
||||
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
||||
# In short, the tensor's shape in training phase is longer than which
|
||||
|
@ -9,12 +9,13 @@ class Actor(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net, action_shape):
|
||||
def __init__(self, preprocess_net, action_shape, hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(128, np.prod(action_shape))
|
||||
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
r"""s -> Q(s, \*)"""
|
||||
logits, h = self.preprocess(s, state)
|
||||
logits = F.softmax(self.last(logits), dim=-1)
|
||||
return logits, h
|
||||
@ -25,12 +26,13 @@ class Critic(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net):
|
||||
def __init__(self, preprocess_net, hidden_layer_size=128):
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(128, 1)
|
||||
self.last = nn.Linear(hidden_layer_size, 1)
|
||||
|
||||
def forward(self, s, **kwargs):
|
||||
"""s -> V(s)"""
|
||||
logits, h = self.preprocess(s, state=kwargs.get('state', None))
|
||||
logits = self.last(logits)
|
||||
return logits
|
||||
@ -62,6 +64,7 @@ class DQN(nn.Module):
|
||||
self.head = nn.Linear(512, action_shape)
|
||||
|
||||
def forward(self, x, state=None, info={}):
|
||||
r"""x -> Q(x, \*)"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.tensor(x, device=self.device, dtype=torch.float32)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user