docs fix and v0.2.5 (#156)

* pre

* update docs

* update docs

* $ in bash

* size -> hidden_layer_size

* doctest

* doctest again

* filter a warning

* fix bug

* fix examples

* test fail

* test succ
n+e 2020-07-22 14:42:08 +08:00 committed by GitHub
parent 089b85b6a2
commit bd9c3c7f8d
21 changed files with 139 additions and 122 deletions

View File

@ -3,15 +3,10 @@
+ [ ] RL algorithm bug
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/)
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, sys
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
```
[source website]: https://github.com/thu-ml/tianshou/
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=

View File

@ -7,15 +7,10 @@
Less important but also useful:
- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou)
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, sys
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
```
[source website]: https://github.com/thu-ml/tianshou
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=

View File

@ -1,4 +1,4 @@
name: PEP8 Check
name: PEP8 and Docs Check
on: [push, pull_request]
@ -11,9 +11,20 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install flake8
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- name: Install dependencies
run: |
pip install ".[dev]" --upgrade
- name: Documentation test
run: |
cd docs
make html SPHINXOPTS="-W"
cd ..

View File

@ -38,7 +38,7 @@ Here is Tianshou's other features:
- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation for all Q-learning based algorithms
- Support multi-agent RL easily [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
In Chinese, Tianshou means "divinely ordained", derived from the notion of a gift one is born with. Tianshou is a reinforcement learning platform, and its RL algorithms do not learn from humans. Taking the name "Tianshou" thus means there is no teacher to study with; the agent instead learns by itself through constant interaction with the environment.
@ -49,24 +49,27 @@ In Chinese, Tianshou means divinely ordained and is derived to the gift of being
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
```bash
pip3 install tianshou
$ pip install tianshou
```
You can also install with the newest version through GitHub:
```bash
pip3 install git+https://github.com/thu-ml/tianshou.git@master
# latest release
$ pip install git+https://github.com/thu-ml/tianshou.git@master
# develop version
$ pip install git+https://github.com/thu-ml/tianshou.git@dev
```
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
```bash
# create a new virtualenv and install pip, change the env name if you like
conda create -n myenv pip
$ conda create -n myenv pip
# activate the environment
conda activate myenv
$ conda activate myenv
# install tianshou
pip install tianshou
$ pip install tianshou
```
After installation, open your python console and type
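The verification snippet elided between the hunks here is presumably the same version check used in the issue templates above:

```python
import tianshou
print(tianshou.__version__)
```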
@ -82,9 +85,9 @@ If no error occurs, you have successfully installed Tianshou.
The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder. It may fail to run with PyPI installation, so please re-install the github version through `pip3 install git+https://github.com/thu-ml/tianshou.git@master`.
The example scripts are under the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
The Chinese documentation is available at [https://tianshou.readthedocs.io/zh/latest/](https://tianshou.readthedocs.io/zh/latest/)
<!-- A short Chinese introduction to the Tianshou platform is available at https://www.zhihu.com/question/377263715 -->
@ -95,7 +98,7 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
Tianshou is a lightweight but high-speed reinforcement learning platform. For example, on a laptop (i7-8750H + GTX1060) it takes only about 3 seconds to train a vanilla policy gradient agent on the CartPole-v0 task (the seed may differ across platforms and devices):
```bash
python3 test/discrete/test_pg.py --seed 0 --render 0.03
$ python3 test/discrete/test_pg.py --seed 0 --render 0.03
```
<div align="center">
@ -108,10 +111,10 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| GitHub Stars | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) |
| Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? |
| DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
| PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
| PG - CartPole | 9.02±6.79s | None | None | 19.26±2.29s | None | ? |
| DQN - CartPole | 6.72±1.28s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
| A2C - CartPole | 15.33±4.48s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
| PPO - CartPole | 6.01±1.14s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
| PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? |
| DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s |
| TD3 - Pendulum | 44.04±6.37s | None | 99.75±21.63s | 149.90±7.54s | 57.52±17.71s \*\* | 210.31±76.30s |
@ -142,7 +145,7 @@ We decouple all of the algorithms into 4 parts:
- `process_fn`: to preprocess data from the replay buffer (since we have reformulated all algorithms as replay-buffer based algorithms);
- `learn`: to learn from a given batch data.
Within these API, we can interact with different policies conveniently.
Within this API, we can interact with different policies conveniently.
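Only two of the four parts survive the hunk above; a rough sketch of how the full interface fits together (method signatures assumed from the v0.2.x `BasePolicy`; `RandomDiscretePolicy` is purely illustrative):

```python
import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomDiscretePolicy(BasePolicy):
    """Toy policy illustrating the decoupled interface."""
    def __init__(self, n_actions):
        super().__init__()
        self.n_actions = n_actions

    def forward(self, batch, state=None, **kwargs):
        # map batch.obs to actions; here: uniform random actions
        act = np.random.randint(self.n_actions, size=len(batch.obs))
        return Batch(act=act, state=state)

    def process_fn(self, batch, buffer, indice):
        # preprocess data sampled from the replay buffer (e.g. n-step returns)
        return batch

    def learn(self, batch, **kwargs):
        # update the model from a batch and return a dict of logged values
        return {}
```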
### Elegant and Flexible
@ -182,17 +185,12 @@ Define some hyper-parameters:
```python
task = 'CartPole-v0'
lr = 1e-3
gamma = 0.9
n_step = 3
eps_train, eps_test = 0.1, 0.05
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
target_freq = 320
batch_size = 64
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 8, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, collect_per_step = 1000, 10
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
```
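The environment setup that the diff skips over as unchanged context presumably feeds `train_num` and `test_num` into tianshou's vectorized environment wrapper; a sketch, assuming the 0.2.x `VectorEnv` API:

```python
import gym
import tianshou as ts

# one environment factory per worker, for training and for testing
train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```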
@ -208,7 +206,8 @@ Define the network:
```python
from tianshou.utils.net.common import Net
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
@ -219,8 +218,7 @@ optim = torch.optim.Adam(net.parameters(), lr=lr)
Setup policy and collectors:
```python
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
target_update_freq=target_freq)
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)
```
@ -236,7 +234,7 @@ result = ts.trainer.offpolicy_trainer(
print(f'Finished training! Use {result["duration"]}')
```
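For reference, the full trainer call truncated by the hunk probably looks roughly like this (argument order and callback names assumed from the 0.2.x `offpolicy_trainer` signature):

```python
result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch,
    collect_per_step, test_num, batch_size,
    train_fn=lambda e: policy.set_eps(eps_train),  # exploration eps while training
    test_fn=lambda e: policy.set_eps(eps_test),    # smaller eps while testing
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    writer=writer)
```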
Save / load the trained policy (it's exactly the same as PyTorch nn.module):
Save / load the trained policy (it's exactly the same as PyTorch `nn.Module`):
```python
torch.save(policy.state_dict(), 'dqn.pth')
@ -254,18 +252,18 @@ collector.close()
Look at the result saved in tensorboard (run this bash command in your terminal):
```bash
tensorboard --logdir log/dqn
$ tensorboard --logdir log/dqn
```
You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
## Contributing
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [docs/contributing.rst](https://github.com/thu-ml/tianshou/blob/master/docs/contributing.rst).
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html).
## TODO
Check out the [Issue/PR Categories](https://github.com/thu-ml/tianshou/projects/2) and [Support Status](https://github.com/thu-ml/tianshou/projects/3) page for more detail.
Check out the [Project](https://github.com/thu-ml/tianshou/projects) page for more detail.
## Citing Tianshou
@ -273,7 +271,7 @@ If you find Tianshou useful, please cite it in your publications.
```latex
@misc{tianshou,
author = {Jiayi Weng, Minghao Zhang, Dong Yan, Hang Su, Jun Zhu},
author = {Jiayi Weng, Minghao Zhang, Alexis Duburcq, Kaichao You, Dong Yan, Hang Su, Jun Zhu},
title = {Tianshou},
year = {2020},
publisher = {GitHub},

View File

@ -41,7 +41,7 @@ extensions = [
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.coverage',
'sphinx.ext.imgmath',
# 'sphinx.ext.imgmath',
'sphinx.ext.mathjax',
'sphinx.ext.ifconfig',
'sphinx.ext.viewcode',
@ -58,7 +58,9 @@ master_doc = 'index'
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
autodoc_default_options = {'special-members': '__call__, __getitem__, __len__'}
autodoc_default_options = {'special-members': ', '.join([
'__len__', '__call__', '__getitem__', '__setitem__',
'__getattr__', '__setattr__'])}
# -- Options for HTML output -------------------------------------------------

View File

@ -8,13 +8,14 @@ To install Tianshou in an "editable" mode, run
.. code-block:: bash
pip3 install -e ".[dev]"
$ git checkout dev
$ pip install -e ".[dev]"
in the main directory. This editable installation can be removed later with
.. code-block:: bash
python3 setup.py develop --uninstall
$ python setup.py develop --uninstall
PEP8 Code Style Check
---------------------
@ -23,7 +24,7 @@ We follow PEP8 python code style. To check, in the main directory, run:
.. code-block:: bash
flake8 . --count --show-source --statistics
$ flake8 . --count --show-source --statistics
Test Locally
------------
@ -32,7 +33,7 @@ This command will run automatic tests in the main directory
.. code-block:: bash
pytest test --cov tianshou -s --durations 0 -v
$ pytest test --cov tianshou -s --durations 0 -v
Test by GitHub Actions
----------------------
@ -65,11 +66,13 @@ To compile documentation into webpages, run
.. code-block:: bash
make html
$ make html
under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
Chinese Documentation
---------------------
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/
Pull Request
------------
All commits should be merged into the ``dev`` branch through pull requests. A pull request must have 2 approvals before merging.

View File

@ -6,4 +6,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
* Jiayi Weng (`Trinkle23897 <https://github.com/Trinkle23897>`_)
* Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
* Kaichao You (`youkaichao <https://github.com/youkaichao>`_)
* Kaichao You (`youkaichao <https://github.com/youkaichao>`_)

View File

@ -10,7 +10,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
@ -28,32 +28,38 @@ Here is Tianshou's other features:
* Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
* Support customized training process: :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
* Support multi-agent RL easily (a tutorial is available at :doc:`/tutorials/tictactoe`)
* Support multi-agent RL: :doc:`/tutorials/tictactoe`
The Chinese documentation is at https://tianshou.readthedocs.io/zh/latest/
The Chinese documentation is at `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
Installation
------------
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command (with Python >= 3.6):
::
pip3 install tianshou
.. code-block:: bash
$ pip install tianshou
You can also install with the newest version through GitHub:
::
pip3 install git+https://github.com/thu-ml/tianshou.git@master
.. code-block:: bash
# latest release
$ pip install git+https://github.com/thu-ml/tianshou.git@master
# develop version
$ pip install git+https://github.com/thu-ml/tianshou.git@dev
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
::
.. code-block:: bash
# create a new virtualenv and install pip, change the env name if you like
conda create -n myenv pip
$ conda create -n myenv pip
# activate the environment
conda activate myenv
$ conda activate myenv
# install tianshou
pip install tianshou
$ pip install tianshou
After installation, open your python console and type
::
@ -63,7 +69,7 @@ After installation, open your python console and type
If no error occurs, you have successfully installed Tianshou.
Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
Tianshou is still under development. You can also check out the documentation for the stable version at `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and for the develop version at `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
.. toctree::
:maxdepth: 1

View File

@ -8,7 +8,6 @@ The full script is at `test/discrete/test_dqn.py <https://github.com/thu-ml/tian
In contrast to existing deep RL libraries such as `RLlib <https://github.com/ray-project/ray/tree/master/rllib/>`_, which only accept a configuration specification of hyperparameters, networks, and so on, Tianshou provides an easy way to construct everything directly at the code level.
Make an Environment
-------------------
@ -22,7 +21,6 @@ First of all, you have to make an environment for your agent to interact with. F
CartPole-v0 is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply an eligible algorithm. DDPG :cite:`DDPG`, for example, can only be applied to continuous action spaces, while almost all other policy gradient methods can be applied to both, depending on the probability distribution over the action.
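A quick way to perform that check in code (a small illustrative snippet, not part of the original tutorial):
::

    import gym

    env = gym.make('CartPole-v0')
    # Discrete -> DQN applies; Box (continuous) -> consider DDPG, PPO, ...
    print(isinstance(env.action_space, gym.spaces.Discrete))  # True for CartPole-v0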
Setup Multi-environment Wrapper
-------------------------------
@ -79,17 +77,13 @@ You can also have a try with those pre-defined networks in :mod:`~tianshou.utils
1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
2. Output: some ``logits``, the next hidden state ``state``, and ``policy``, an intermediate result produced during the policy forwarding procedure. The logits can be a tuple instead of a ``torch.Tensor``, depending on how the policy processes the network output. For example, in PPO :cite:`PPO`, the network might return ``(mu, sigma), state`` for a Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other data, which will be stored in the replay buffer and can be accessed during the policy update (e.g. in ``policy.learn()``, ``batch.policy`` is what you need).
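A minimal network following the input/output convention in the two points above (a sketch in the spirit of the pre-defined networks; the 128-unit width and single hidden layer are arbitrary choices):
::

    import numpy as np
    import torch
    from torch import nn

    class MyNet(nn.Module):
        def __init__(self, state_shape, action_shape):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(int(np.prod(state_shape)), 128), nn.ReLU(inplace=True),
                nn.Linear(128, int(np.prod(action_shape))))

        def forward(self, obs, state=None, info={}):
            # obs may arrive as a numpy.ndarray from the environment
            if not isinstance(obs, torch.Tensor):
                obs = torch.tensor(obs, dtype=torch.float32)
            logits = self.model(obs.flatten(1))
            # no RNN here, so the hidden state passes through unchanged
            return logits, state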
Setup Policy
------------
We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy using a target network:
::
policy = ts.policy.DQNPolicy(net, optim,
discount_factor=0.9, estimation_step=3,
target_update_freq=320)
policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)
Setup Collector
---------------
@ -101,7 +95,6 @@ In each step, the collector will let the policy perform (at least) a specified n
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
test_collector = ts.data.Collector(policy, test_envs)
Train Policy with a Trainer
---------------------------
@ -118,7 +111,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
writer=None)
print(f'Finished training! Use {result["duration"]}')
The meaning of each parameter is as follows:
The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`):
* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
* ``step_per_epoch``: The number of steps for updating the policy network in one epoch;
@ -157,7 +150,6 @@ The returned result is a dictionary as follows:
It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean return over 100 consecutive episodes is 199.03.
Save/Load Policy
----------------
@ -167,7 +159,6 @@ Since the policy inherits the ``torch.nn.Module`` class, saving and loading the
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
Watch the Agent's Performance
-----------------------------
@ -178,7 +169,6 @@ Watch the Agent's Performance
collector.collect(n_episode=1, render=1 / 35)
collector.close()
.. _customized_trainer:
Train a Policy with Customized Codes
@ -186,7 +176,7 @@ Train a Policy with Customized Codes
"I don't want to use your provided trainer. I want to customize it!"
No problem! Tianshou supports user-defined training code. Here is the usage:
Tianshou supports user-defined training code. Here is the code snippet:
::
# pre-collect 5000 frames with random action before training
@ -212,7 +202,7 @@ No problem! Tianshou supports user-defined training code. Here is the usage:
# train policy with a sampled batch data
losses = policy.learn(train_collector.sample(batch_size=64))
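The loop body elided by the hunk presumably follows the published tutorial; a hedged reconstruction (the epsilon schedule and the exact ``collect``/``sample`` keywords are assumptions based on the 0.2.x Collector API):
::

    # pre-collect 5000 frames with random action before training
    policy.set_eps(1)
    train_collector.collect(n_step=5000)

    policy.set_eps(0.1)
    for i in range(int(1e6)):
        collect_result = train_collector.collect(n_step=10)

        # once the collected episodes' mean returns reach the threshold, test the policy
        if collect_result['rew'] >= env.spec.reward_threshold:
            policy.set_eps(0.05)
            result = test_collector.collect(n_episode=100)
            if result['rew'] >= env.spec.reward_threshold:
                print(f'Finished training! Test mean returns: {result["rew"]}')
                break
            else:
                # back to training with exploration noise
                policy.set_eps(0.1)

        # train policy with a sampled batch data
        losses = policy.learn(train_collector.sample(batch_size=64))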
For further usage, you can refer to :doc:`/tutorials/cheatsheet`.
For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.
.. rubric:: References

View File

@ -162,8 +162,8 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
Random agents perform badly. In the above game, although agent 2 eventually wins, it is clear that a smart agent 1 would have placed an ``x`` at row 4, col 4 to win directly.
Train a MARL Agent
------------------
Train an MARL Agent
-------------------
So let's start to train our Tic-Tac-Toe agent! First, import some required modules.
::

View File

@ -84,7 +84,7 @@ def test_ddpg(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -94,7 +94,7 @@ def test_td3(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -19,7 +19,7 @@ def create_atari_environment(name=None, sticky_actions=True,
def preprocess_fn(obs=None, act=None, rew=None, done=None,
obs_next=None, info=None, policy=None):
obs_next=None, info=None, policy=None, **kwargs):
if obs_next is not None:
obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
obs_next = np.moveaxis(obs_next, 0, -1)

View File

@ -101,7 +101,7 @@ def test_td3(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -89,8 +89,7 @@ def test_a2c(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
train_collector.close()
test_collector.close()
if __name__ == '__main__':

View File

@ -94,7 +94,7 @@ def test_dqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, writer=writer, task=args.task)
stop_fn=stop_fn, writer=writer)
train_collector.close()
test_collector.close()

View File

@ -93,8 +93,7 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
train_collector.close()
test_collector.close()
if __name__ == '__main__':

View File

@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, \
exploration
__version__ = '0.2.4'
__version__ = '0.2.5'
__all__ = [
'env',
'data',

View File

@ -10,29 +10,31 @@ class Net(nn.Module):
please refer to :ref:`build_the_network`.
:param concat: whether the input shape is concatenated by state_shape
and action_shape. If it is True, ``action_shape`` is not the output
shape, but affects the input shape.
and action_shape. If it is True, ``action_shape`` is not the output
shape, but affects the input shape.
"""
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
softmax=False, concat=False):
softmax=False, concat=False, hidden_layer_size=128):
super().__init__()
self.device = device
input_size = np.prod(state_shape)
if concat:
input_size += np.prod(action_shape)
self.model = [
nn.Linear(input_size, 128),
nn.Linear(input_size, hidden_layer_size),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
self.model += [nn.Linear(hidden_layer_size, hidden_layer_size),
nn.ReLU(inplace=True)]
if action_shape and not concat:
self.model += [nn.Linear(128, np.prod(action_shape))]
self.model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
if softmax:
self.model += [nn.Softmax(dim=-1)]
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):
"""s -> flatten -> logits"""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
logits = self.model(s)
@ -44,17 +46,23 @@ class Recurrent(nn.Module):
customize the network), please refer to :ref:`build_the_network`.
"""
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
def __init__(self, layer_num, state_shape, action_shape,
device='cpu', hidden_layer_size=128):
super().__init__()
self.state_shape = state_shape
self.action_shape = action_shape
self.device = device
self.nn = nn.LSTM(input_size=128, hidden_size=128,
self.nn = nn.LSTM(input_size=hidden_layer_size,
hidden_size=hidden_layer_size,
num_layers=layer_num, batch_first=True)
self.fc1 = nn.Linear(np.prod(state_shape), 128)
self.fc2 = nn.Linear(128, np.prod(action_shape))
self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size)
self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
def forward(self, s, state=None, info={}):
"""In the evaluation mode, s should be with shape ``[bsz, dim]``; in
the training mode, s should be with shape ``[bsz, len, dim]``. See the
code and comment for more detail.
"""
s = to_torch(s, device=self.device, dtype=torch.float32)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in training phase is longer than which
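A small usage sketch of the new ``hidden_layer_size`` argument added in this file (shapes chosen for a CartPole-sized problem; both classes still default to 128, so existing callers are unaffected):

```python
from tianshou.utils.net.common import Net, Recurrent

# widen the MLP and LSTM backbones from the default 128 to 256 hidden units
net = Net(layer_num=1, state_shape=4, action_shape=2, hidden_layer_size=256)
rnn = Recurrent(layer_num=1, state_shape=4, action_shape=2, hidden_layer_size=256)
```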

View File

@ -11,13 +11,14 @@ class Actor(nn.Module):
"""
def __init__(self, preprocess_net, action_shape,
max_action, device='cpu'):
max_action, device='cpu', hidden_layer_size=128):
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(128, np.prod(action_shape))
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
self._max = max_action
def forward(self, s, state=None, info={}):
"""s -> logits -> action"""
logits, h = self.preprocess(s, state)
logits = self._max * torch.tanh(self.last(logits))
return logits, h
@ -28,13 +29,14 @@ class Critic(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, device='cpu'):
def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128):
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.last = nn.Linear(128, 1)
self.last = nn.Linear(hidden_layer_size, 1)
def forward(self, s, a=None, **kwargs):
"""(s, a) -> logits -> Q(s, a)"""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
if a is not None:
@ -51,17 +53,18 @@ class ActorProb(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, action_shape,
max_action, device='cpu', unbounded=False):
def __init__(self, preprocess_net, action_shape, max_action,
device='cpu', unbounded=False, hidden_layer_size=128):
super().__init__()
self.preprocess = preprocess_net
self.device = device
self.mu = nn.Linear(128, np.prod(action_shape))
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self._max = max_action
self._unbounded = unbounded
def forward(self, s, state=None, **kwargs):
"""s -> logits -> (mu, sigma)"""
logits, h = self.preprocess(s, state)
mu = self.mu(logits)
if not self._unbounded:
@ -78,15 +81,17 @@ class RecurrentActorProb(nn.Module):
"""
def __init__(self, layer_num, state_shape, action_shape,
max_action, device='cpu'):
max_action, device='cpu', hidden_layer_size=128):
super().__init__()
self.device = device
self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
self.nn = nn.LSTM(input_size=np.prod(state_shape),
hidden_size=hidden_layer_size,
num_layers=layer_num, batch_first=True)
self.mu = nn.Linear(128, np.prod(action_shape))
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
def forward(self, s, **kwargs):
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = to_torch(s, device=self.device, dtype=torch.float32)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in training phase is longer than which
@ -107,16 +112,19 @@ class RecurrentCritic(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
def __init__(self, layer_num, state_shape,
action_shape=0, device='cpu', hidden_layer_size=128):
super().__init__()
self.state_shape = state_shape
self.action_shape = action_shape
self.device = device
self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
self.nn = nn.LSTM(input_size=np.prod(state_shape),
hidden_size=hidden_layer_size,
num_layers=layer_num, batch_first=True)
self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1)
def forward(self, s, a=None):
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = to_torch(s, device=self.device, dtype=torch.float32)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in training phase is longer than which

View File

@ -9,12 +9,13 @@ class Actor(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, action_shape):
def __init__(self, preprocess_net, action_shape, hidden_layer_size=128):
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(128, np.prod(action_shape))
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
def forward(self, s, state=None, info={}):
r"""s -> Q(s, \*)"""
logits, h = self.preprocess(s, state)
logits = F.softmax(self.last(logits), dim=-1)
return logits, h
@ -25,12 +26,13 @@ class Critic(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net):
def __init__(self, preprocess_net, hidden_layer_size=128):
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(128, 1)
self.last = nn.Linear(hidden_layer_size, 1)
def forward(self, s, **kwargs):
"""s -> V(s)"""
logits, h = self.preprocess(s, state=kwargs.get('state', None))
logits = self.last(logits)
return logits
@ -62,6 +64,7 @@ class DQN(nn.Module):
self.head = nn.Linear(512, action_shape)
def forward(self, x, state=None, info={}):
r"""x -> Q(x, \*)"""
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32)
x = x.permute(0, 3, 1, 2)