docs fix and v0.2.5 (#156)
* pre
* update docs
* update docs
* $ in bash
* size -> hidden_layer_size
* doctest
* doctest again
* filter a warning
* fix bug
* fix examples
* test fail
* test succ
parent 089b85b6a2
commit bd9c3c7f8d
.github/ISSUE_TEMPLATE.md (9 lines changed)

@@ -3,15 +3,10 @@
 + [ ] RL algorithm bug
 + [ ] documentation request (i.e. "X is missing from the documentation.")
 + [ ] new feature request
-- [ ] I have visited the [source website], and in particular read the [known issues]
+- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/)
-- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
+- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
 ```python
 import tianshou, torch, sys
 print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
 ```
-
-[source website]: https://github.com/thu-ml/tianshou/
-[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
-[issue categories]: https://github.com/thu-ml/tianshou/projects/2
-[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=
.github/PULL_REQUEST_TEMPLATE.md (9 lines changed)

@@ -7,15 +7,10 @@
 
 Less important but also useful:
 
-- [ ] I have visited the [source website], and in particular read the [known issues]
+- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou)
-- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
+- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
 ```python
 import tianshou, torch, sys
 print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
 ```
-
-[source website]: https://github.com/thu-ml/tianshou
-[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
-[issue categories]: https://github.com/thu-ml/tianshou/projects/2
-[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=
@@ -1,4 +1,4 @@
-name: PEP8 Check
+name: PEP8 and Docs Check
 
 on: [push, pull_request]
 
@@ -11,9 +11,20 @@ jobs:
 uses: actions/setup-python@v2
 with:
 python-version: 3.8
+- name: Upgrade pip
+run: |
+python -m pip install --upgrade pip setuptools wheel
 - name: Install dependencies
 run: |
 python -m pip install flake8
 - name: Lint with flake8
 run: |
 flake8 . --count --show-source --statistics
+- name: Install dependencies
+run: |
+pip install ".[dev]" --upgrade
+- name: Documentation test
+run: |
+cd docs
+make html SPHINXOPTS="-W"
+cd ..
README.md (60 lines changed)

@@ -38,7 +38,7 @@ Here is Tianshou's other features:
 - Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
 - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
 - Support n-step returns estimation for all Q-learning based algorithms
-- Support multi-agent RL easily [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
+- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
 
 In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
 
@@ -49,24 +49,27 @@ In Chinese, Tianshou means divinely ordained and is derived to the gift of being
 Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
 
 ```bash
-pip3 install tianshou
+$ pip install tianshou
 ```
 
 You can also install with the newest version through GitHub:
 
 ```bash
-pip3 install git+https://github.com/thu-ml/tianshou.git@master
+# latest release
+$ pip install git+https://github.com/thu-ml/tianshou.git@master
+# develop version
+$ pip install git+https://github.com/thu-ml/tianshou.git@dev
 ```
 
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
 
 ```bash
 # create a new virtualenv and install pip, change the env name if you like
-conda create -n myenv pip
+$ conda create -n myenv pip
 # activate the environment
-conda activate myenv
+$ conda activate myenv
 # install tianshou
-pip install tianshou
+$ pip install tianshou
 ```
 
 After installation, open your python console and type
@@ -82,9 +85,9 @@ If no error occurs, you have successfully installed Tianshou.
 
 The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
 
-The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder. It may fail to run with PyPI installation, so please re-install the github version through `pip3 install git+https://github.com/thu-ml/tianshou.git@master`.
+The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder.
 
-中文文档位于 [https://tianshou.readthedocs.io/zh/latest/](https://tianshou.readthedocs.io/zh/latest/)
+中文文档位于 [https://tianshou.readthedocs.io/zh/latest/](https://tianshou.readthedocs.io/zh/latest/)。
 
 <!-- 这里有一份天授平台简短的中文简介:https://www.zhihu.com/question/377263715 -->
 
@@ -95,7 +98,7 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
 Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It only uses 3 seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device)
 
 ```bash
-python3 test/discrete/test_pg.py --seed 0 --render 0.03
+$ python3 test/discrete/test_pg.py --seed 0 --render 0.03
 ```
 
 <div align="center">
@@ -108,10 +111,10 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
 | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/hill-a/stable-baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
 | Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
-| PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? |
+| PG - CartPole | 9.02±6.79s | None | None | 19.26±2.29s | None | ? |
-| DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
+| DQN - CartPole | 6.72±1.28s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
-| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
+| A2C - CartPole | 15.33±4.48s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
-| PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
+| PPO - CartPole | 6.01±1.14s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
 | PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? |
 | DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s |
 | TD3 - Pendulum | 44.04±6.37s | None | 99.75±21.63s | 149.90±7.54s | 57.52±17.71s \*\* | 210.31±76.30s |
@@ -142,7 +145,7 @@ We decouple all of the algorithms into 4 parts:
 - `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
 - `learn`: to learn from a given batch data.
 
-Within these API, we can interact with different policies conveniently.
+Within this API, we can interact with different policies conveniently.
 
 ### Elegant and Flexible
 
@@ -182,17 +185,12 @@ Define some hyper-parameters:
 
 ```python
 task = 'CartPole-v0'
-lr = 1e-3
-gamma = 0.9
-n_step = 3
-eps_train, eps_test = 0.1, 0.05
-epoch = 10
-step_per_epoch = 1000
-collect_per_step = 10
-target_freq = 320
-batch_size = 64
+lr, epoch, batch_size = 1e-3, 10, 64
 train_num, test_num = 8, 100
+gamma, n_step, target_freq = 0.9, 3, 320
 buffer_size = 20000
+eps_train, eps_test = 0.1, 0.05
+step_per_epoch, collect_per_step = 1000, 10
 writer = SummaryWriter('log/dqn') # tensorboard is also supported!
 ```
 
@@ -208,7 +206,8 @@ Define the network:
 
 ```python
 from tianshou.utils.net.common import Net
+# you can define other net by following the API:
+# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network
 env = gym.make(task)
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
@@ -219,8 +218,7 @@ optim = torch.optim.Adam(net.parameters(), lr=lr)
 Setup policy and collectors:
 
 ```python
-policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
-target_update_freq=target_freq)
+policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
 train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
 test_collector = ts.data.Collector(policy, test_envs)
 ```
@@ -236,7 +234,7 @@ result = ts.trainer.offpolicy_trainer(
 print(f'Finished training! Use {result["duration"]}')
 ```
 
-Save / load the trained policy (it's exactly the same as PyTorch nn.module):
+Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):
 
 ```python
 torch.save(policy.state_dict(), 'dqn.pth')
@@ -254,18 +252,18 @@ collector.close()
 Look at the result saved in tensorboard: (with bash script in your terminal)
 
 ```bash
-tensorboard --logdir log/dqn
+$ tensorboard --logdir log/dqn
 ```
 
 You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
 
 ## Contributing
 
-Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [docs/contributing.rst](https://github.com/thu-ml/tianshou/blob/master/docs/contributing.rst).
+Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html).
 
 ## TODO
 
-Check out the [Issue/PR Categories](https://github.com/thu-ml/tianshou/projects/2) and [Support Status](https://github.com/thu-ml/tianshou/projects/3) page for more detail.
+Check out the [Project](https://github.com/thu-ml/tianshou/projects) page for more detail.
 
 ## Citing Tianshou
 
@@ -273,7 +271,7 @@ If you find Tianshou useful, please cite it in your publications.
 
 ```latex
 @misc{tianshou,
-author = {Jiayi Weng, Minghao Zhang, Dong Yan, Hang Su, Jun Zhu},
+author = {Jiayi Weng, Minghao Zhang, Alexis Duburcq, Kaichao You, Dong Yan, Hang Su, Jun Zhu},
 title = {Tianshou},
 year = {2020},
 publisher = {GitHub},
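For readers reconstructing the README quickstart from the hunks above, here is a minimal sketch that assembles the shown snippets into one script. The vectorized-environment construction (`ts.env.VectorEnv`), the choice of `layer_num=3`, and the `train_fn`/`test_fn`/`stop_fn` lambdas are assumptions that do not appear in this diff; treat the block as an illustrative sketch of the v0.2.5 API rather than a verbatim excerpt.

```python
import gym
import torch
import tianshou as ts
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils.net.common import Net

task = 'CartPole-v0'
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 8, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, collect_per_step = 1000, 10
writer = SummaryWriter('log/dqn')  # tensorboard is also supported!

# vectorized environments; ts.env.VectorEnv is an assumption (not shown in this diff)
env = gym.make(task)
train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])

# network and optimizer, using the Net signature changed later in this commit
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(layer_num=3, state_shape=state_shape, action_shape=action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)

# policy, collectors and trainer as in the README hunks above
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)

result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch,
    step_per_epoch, collect_per_step, test_num, batch_size,
    train_fn=lambda e: policy.set_eps(eps_train),  # assumed epsilon schedule
    test_fn=lambda e: policy.set_eps(eps_test),
    stop_fn=lambda mean_reward: mean_reward >= env.spec.reward_threshold,
    writer=writer)
print(f'Finished training! Use {result["duration"]}')
```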
@@ -41,7 +41,7 @@ extensions = [
 'sphinx.ext.doctest',
 'sphinx.ext.intersphinx',
 'sphinx.ext.coverage',
-'sphinx.ext.imgmath',
+# 'sphinx.ext.imgmath',
 'sphinx.ext.mathjax',
 'sphinx.ext.ifconfig',
 'sphinx.ext.viewcode',
@@ -58,7 +58,9 @@ master_doc = 'index'
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
 exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
-autodoc_default_options = {'special-members': '__call__, __getitem__, __len__'}
+autodoc_default_options = {'special-members': ', '.join([
+'__len__', '__call__', '__getitem__', '__setitem__',
+'__getattr__', '__setattr__'])}
 
 # -- Options for HTML output -------------------------------------------------
 
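The second hunk above replaces a hard-coded string with a `', '.join(...)` over a list of dunder names, which evaluates to the same kind of comma-separated value while staying easy to extend. A quick check of what the new expression produces:

```python
special = ', '.join([
    '__len__', '__call__', '__getitem__', '__setitem__',
    '__getattr__', '__setattr__'])
print(special)
# __len__, __call__, __getitem__, __setitem__, __getattr__, __setattr__
```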
@@ -8,13 +8,14 @@ To install Tianshou in an "editable" mode, run
 
 .. code-block:: bash
 
-pip3 install -e ".[dev]"
+$ git checkout dev
+$ pip install -e ".[dev]"
 
 in the main directory. This installation is removable by
 
 .. code-block:: bash
 
-python3 setup.py develop --uninstall
+$ python setup.py develop --uninstall
 
 PEP8 Code Style Check
 ---------------------
@@ -23,7 +24,7 @@ We follow PEP8 python code style. To check, in the main directory, run:
 
 .. code-block:: bash
 
-flake8 . --count --show-source --statistics
+$ flake8 . --count --show-source --statistics
 
 Test Locally
 ------------
@@ -32,7 +33,7 @@ This command will run automatic tests in the main directory
 
 .. code-block:: bash
 
-pytest test --cov tianshou -s --durations 0 -v
+$ pytest test --cov tianshou -s --durations 0 -v
 
 Test by GitHub Actions
 ----------------------
@@ -65,11 +66,13 @@ To compile documentation into webpages, run
 
 .. code-block:: bash
 
-make html
+$ make html
 
 under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
 
-Chinese Documentation
----------------------
-
-Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/
+Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
+
+Pull Request
+------------
+
+All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
@@ -10,7 +10,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
-* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
+* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
@@ -28,32 +28,38 @@ Here is Tianshou's other features:
 * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
 * Support customized training process: :ref:`customize_training`
 * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
-* Support multi-agent RL easily (a tutorial is available at :doc:`/tutorials/tictactoe`)
+* Support multi-agent RL: :doc:`/tutorials/tictactoe`
 
-中文文档位于 https://tianshou.readthedocs.io/zh/latest/
+中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
 
 Installation
 ------------
 
 Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command (with Python >= 3.6):
-::
-
-pip3 install tianshou
+
+.. code-block:: bash
+
+$ pip install tianshou
 
 You can also install with the newest version through GitHub:
-::
-
-pip3 install git+https://github.com/thu-ml/tianshou.git@master
+
+.. code-block:: bash
+
+# latest release
+$ pip install git+https://github.com/thu-ml/tianshou.git@master
+# develop version
+$ pip install git+https://github.com/thu-ml/tianshou.git@dev
 
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
-::
+
+.. code-block:: bash
 
 # create a new virtualenv and install pip, change the env name if you like
-conda create -n myenv pip
+$ conda create -n myenv pip
 # activate the environment
-conda activate myenv
+$ conda activate myenv
 # install tianshou
-pip install tianshou
+$ pip install tianshou
 
 After installation, open your python console and type
 ::
@@ -63,7 +69,7 @@ After installation, open your python console and type
 
 If no error occurs, you have successfully installed Tianshou.
 
-Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
+Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
 
 .. toctree::
 :maxdepth: 1
@@ -8,7 +8,6 @@ The full script is at `test/discrete/test_dqn.py <https://github.com/thu-ml/tian
 
 Contrary to existing Deep RL libraries such as `RLlib <https://github.com/ray-project/ray/tree/master/rllib/>`_, which could only accept a config specification of hyperparameters, network, and others, Tianshou provides an easy way of construction through the code-level.
 
-
 Make an Environment
 -------------------
 
@@ -22,7 +21,6 @@ First of all, you have to make an environment for your agent to interact with. F
 
 CartPole-v0 is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both, depending on the probability distribution on the action.
 
-
 Setup Multi-environment Wrapper
 -------------------------------
 
@@ -79,17 +77,13 @@ You can also have a try with those pre-defined networks in :mod:`~tianshou.utils
 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
 2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
 
-
 Setup Policy
 ------------
 
 We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with using a target network:
 ::
 
-policy = ts.policy.DQNPolicy(net, optim,
-discount_factor=0.9, estimation_step=3,
-target_update_freq=320)
+policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)
 
-
 Setup Collector
 ---------------
@@ -101,7 +95,6 @@ In each step, the collector will let the policy perform (at least) a specified n
 train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
 test_collector = ts.data.Collector(policy, test_envs)
 
-
 Train Policy with a Trainer
 ---------------------------
 
@@ -118,7 +111,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
 writer=None)
 print(f'Finished training! Use {result["duration"]}')
 
-The meaning of each parameter is as follows:
+The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`):
 
 * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
 * ``step_per_epoch``: The number of step for updating policy network in one epoch;
@@ -157,7 +150,6 @@ The returned result is a dictionary as follows:
 
 It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03.
 
-
 Save/Load Policy
 ----------------
 
@@ -167,7 +159,6 @@ Since the policy inherits the ``torch.nn.Module`` class, saving and loading the
 torch.save(policy.state_dict(), 'dqn.pth')
 policy.load_state_dict(torch.load('dqn.pth'))
 
-
 Watch the Agent's Performance
 -----------------------------
 
@@ -178,7 +169,6 @@ Watch the Agent's Performance
 collector.collect(n_episode=1, render=1 / 35)
 collector.close()
 
-
 .. _customized_trainer:
 
 Train a Policy with Customized Codes
@@ -186,7 +176,7 @@ Train a Policy with Customized Codes
 
 "I don't want to use your provided trainer. I want to customize it!"
 
-No problem! Tianshou supports user-defined training code. Here is the usage:
+Tianshou supports user-defined training code. Here is the code snippet:
 ::
 
 # pre-collect 5000 frames with random action before training
@@ -212,7 +202,7 @@ No problem! Tianshou supports user-defined training code. Here is the usage:
 # train policy with a sampled batch data
 losses = policy.learn(train_collector.sample(batch_size=64))
 
-For further usage, you can refer to :doc:`/tutorials/cheatsheet`.
+For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.
 
 .. rubric:: References
 
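The customized-training hunk only shows the first comment and the final `policy.learn(...)` line of the tutorial's loop. As a hedged sketch of what such a loop can look like, assuming `policy`, `train_collector`, and `env` were built as earlier in the tutorial; the epsilon value, step budget, stop condition, and the `rew` key read from the collector's return value are illustrative assumptions, not text from the tutorial:

```python
# pre-collect some transitions before training starts
train_collector.collect(n_step=5000)

policy.set_eps(0.1)  # exploration noise while training (assumed value)
for step in range(10000):
    # let the policy interact with the environments for a few steps
    result = train_collector.collect(n_step=10)
    # stop once collected episodes reach the task's reward threshold (assumed)
    if result['rew'] >= env.spec.reward_threshold:
        print(f'Finished training! Reward: {result["rew"]}')
        break
    # train policy with a sampled batch data
    losses = policy.learn(train_collector.sample(batch_size=64))
```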
@@ -162,8 +162,8 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
 
 Random agents perform badly. In the above game, although agent 2 wins finally, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly.
 
-Train a MARL Agent
-------------------
+Train an MARL Agent
+-------------------
 
 So let's start to train our Tic-Tac-Toe agent! First, import some required modules.
 ::
@@ -84,7 +84,7 @@ def test_ddpg(args=get_args()):
 result = offpolicy_trainer(
 policy, train_collector, test_collector, args.epoch,
 args.step_per_epoch, args.collect_per_step, args.test_num,
-args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+args.batch_size, stop_fn=stop_fn, writer=writer)
 assert stop_fn(result['best_reward'])
 train_collector.close()
 test_collector.close()
@@ -94,7 +94,7 @@ def test_td3(args=get_args()):
 result = offpolicy_trainer(
 policy, train_collector, test_collector, args.epoch,
 args.step_per_epoch, args.collect_per_step, args.test_num,
-args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+args.batch_size, stop_fn=stop_fn, writer=writer)
 assert stop_fn(result['best_reward'])
 train_collector.close()
 test_collector.close()
@@ -19,7 +19,7 @@ def create_atari_environment(name=None, sticky_actions=True,
 
 
 def preprocess_fn(obs=None, act=None, rew=None, done=None,
-obs_next=None, info=None, policy=None):
+obs_next=None, info=None, policy=None, **kwargs):
 if obs_next is not None:
 obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
 obs_next = np.moveaxis(obs_next, 0, -1)
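The last hunk above widens `preprocess_fn` so it tolerates extra keyword arguments passed by the collector. A hedged sketch of a preprocessor written in that style; the convention of returning a dict containing only the modified fields is an assumption based on the surrounding example, not something this diff states:

```python
import numpy as np


def preprocess_fn(obs=None, act=None, rew=None, done=None,
                  obs_next=None, info=None, policy=None, **kwargs):
    """Reshape stacked frames in obs_next; ignore any extra keyword arguments."""
    result = {}
    if obs_next is not None:
        obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
        obs_next = np.moveaxis(obs_next, 0, -1)
        result['obs_next'] = obs_next
    return result  # only the keys that were actually changed
```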
@@ -101,7 +101,7 @@ def test_td3(args=get_args()):
 result = offpolicy_trainer(
 policy, train_collector, test_collector, args.epoch,
 args.step_per_epoch, args.collect_per_step, args.test_num,
-args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+args.batch_size, stop_fn=stop_fn, writer=writer)
 assert stop_fn(result['best_reward'])
 train_collector.close()
 test_collector.close()
@@ -89,8 +89,7 @@ def test_a2c(args=get_args()):
 result = onpolicy_trainer(
 policy, train_collector, test_collector, args.epoch,
 args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
-task=args.task)
+args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
 train_collector.close()
 test_collector.close()
 if __name__ == '__main__':
@@ -94,7 +94,7 @@ def test_dqn(args=get_args()):
 policy, train_collector, test_collector, args.epoch,
 args.step_per_epoch, args.collect_per_step, args.test_num,
 args.batch_size, train_fn=train_fn, test_fn=test_fn,
-stop_fn=stop_fn, writer=writer, task=args.task)
+stop_fn=stop_fn, writer=writer)
 
 train_collector.close()
 test_collector.close()
@@ -93,8 +93,7 @@ def test_ppo(args=get_args()):
 result = onpolicy_trainer(
 policy, train_collector, test_collector, args.epoch,
 args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
-task=args.task)
+args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
 train_collector.close()
 test_collector.close()
 if __name__ == '__main__':
@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, \
 exploration
 
-__version__ = '0.2.4'
+__version__ = '0.2.5'
 __all__ = [
 'env',
 'data',
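After upgrading, the bump to 0.2.5 can be confirmed with the same one-liner the issue and pull-request templates ask for:

```python
import sys

import tianshou
import torch

print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
# tianshou.__version__ should now report 0.2.5
```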
@@ -10,29 +10,31 @@ class Net(nn.Module):
 please refer to :ref:`build_the_network`.
 
 :param concat: whether the input shape is concatenated by state_shape
 and action_shape. If it is True, ``action_shape`` is not the output
 shape, but affects the input shape.
 """
 
 def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
-softmax=False, concat=False):
+softmax=False, concat=False, hidden_layer_size=128):
 super().__init__()
 self.device = device
 input_size = np.prod(state_shape)
 if concat:
 input_size += np.prod(action_shape)
 self.model = [
-nn.Linear(input_size, 128),
+nn.Linear(input_size, hidden_layer_size),
 nn.ReLU(inplace=True)]
 for i in range(layer_num):
-self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+self.model += [nn.Linear(hidden_layer_size, hidden_layer_size),
+nn.ReLU(inplace=True)]
 if action_shape and not concat:
-self.model += [nn.Linear(128, np.prod(action_shape))]
+self.model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
 if softmax:
 self.model += [nn.Softmax(dim=-1)]
 self.model = nn.Sequential(*self.model)
 
 def forward(self, s, state=None, info={}):
+"""s -> flatten -> logits"""
 s = to_torch(s, device=self.device, dtype=torch.float32)
 s = s.flatten(1)
 logits = self.model(s)
@@ -44,17 +46,23 @@ class Recurrent(nn.Module):
 customize the network), please refer to :ref:`build_the_network`.
 """
 
-def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+def __init__(self, layer_num, state_shape, action_shape,
+device='cpu', hidden_layer_size=128):
 super().__init__()
 self.state_shape = state_shape
 self.action_shape = action_shape
 self.device = device
-self.nn = nn.LSTM(input_size=128, hidden_size=128,
+self.nn = nn.LSTM(input_size=hidden_layer_size,
+hidden_size=hidden_layer_size,
 num_layers=layer_num, batch_first=True)
-self.fc1 = nn.Linear(np.prod(state_shape), 128)
-self.fc2 = nn.Linear(128, np.prod(action_shape))
+self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size)
+self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
 
 def forward(self, s, state=None, info={}):
+"""In the evaluation mode, s should be with shape ``[bsz, dim]``; in
+the training mode, s should be with shape ``[bsz, len, dim]``. See the
+code and comment for more detail.
+"""
 s = to_torch(s, device=self.device, dtype=torch.float32)
 # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
 # In short, the tensor's shape in training phase is longer than which
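The `hidden_layer_size` keyword added above replaces the hard-coded width of 128, so the MLP and LSTM backbones can be widened or narrowed without editing the library. A small sketch of how the new keyword might be used; the CartPole shapes and the forward call at the end are just an illustration, not taken from this diff:

```python
import gym
import torch
from tianshou.utils.net.common import Net, Recurrent

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# default behaviour is unchanged: hidden_layer_size=128
net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)

# a wider MLP, only possible through the new keyword
wide_net = Net(layer_num=2, state_shape=state_shape,
               action_shape=action_shape, hidden_layer_size=256)

# the recurrent backbone takes the same keyword
rnn = Recurrent(layer_num=1, state_shape=state_shape,
                action_shape=action_shape, hidden_layer_size=256)

obs = torch.rand(4, *env.observation_space.shape)  # a fake batch of 4 observations
logits, state = wide_net(obs)
print(logits.shape)  # (4, 2) for CartPole: one Q-value per action
```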
@@ -11,13 +11,14 @@ class Actor(nn.Module):
 """
 
 def __init__(self, preprocess_net, action_shape,
-max_action, device='cpu'):
+max_action, device='cpu', hidden_layer_size=128):
 super().__init__()
 self.preprocess = preprocess_net
-self.last = nn.Linear(128, np.prod(action_shape))
+self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
 self._max = max_action
 
 def forward(self, s, state=None, info={}):
+"""s -> logits -> action"""
 logits, h = self.preprocess(s, state)
 logits = self._max * torch.tanh(self.last(logits))
 return logits, h
@@ -28,13 +29,14 @@ class Critic(nn.Module):
 :ref:`build_the_network`.
 """
 
-def __init__(self, preprocess_net, device='cpu'):
+def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128):
 super().__init__()
 self.device = device
 self.preprocess = preprocess_net
-self.last = nn.Linear(128, 1)
+self.last = nn.Linear(hidden_layer_size, 1)
 
 def forward(self, s, a=None, **kwargs):
+"""(s, a) -> logits -> Q(s, a)"""
 s = to_torch(s, device=self.device, dtype=torch.float32)
 s = s.flatten(1)
 if a is not None:
@@ -51,17 +53,18 @@ class ActorProb(nn.Module):
 :ref:`build_the_network`.
 """
 
-def __init__(self, preprocess_net, action_shape,
-max_action, device='cpu', unbounded=False):
+def __init__(self, preprocess_net, action_shape, max_action,
+device='cpu', unbounded=False, hidden_layer_size=128):
 super().__init__()
 self.preprocess = preprocess_net
 self.device = device
-self.mu = nn.Linear(128, np.prod(action_shape))
+self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
 self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
 self._max = max_action
 self._unbounded = unbounded
 
 def forward(self, s, state=None, **kwargs):
+"""s -> logits -> (mu, sigma)"""
 logits, h = self.preprocess(s, state)
 mu = self.mu(logits)
 if not self._unbounded:
@@ -78,15 +81,17 @@ class RecurrentActorProb(nn.Module):
 """
 
 def __init__(self, layer_num, state_shape, action_shape,
-max_action, device='cpu'):
+max_action, device='cpu', hidden_layer_size=128):
 super().__init__()
 self.device = device
-self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
+self.nn = nn.LSTM(input_size=np.prod(state_shape),
+hidden_size=hidden_layer_size,
 num_layers=layer_num, batch_first=True)
-self.mu = nn.Linear(128, np.prod(action_shape))
+self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
 self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
 
 def forward(self, s, **kwargs):
+"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
 s = to_torch(s, device=self.device, dtype=torch.float32)
 # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
 # In short, the tensor's shape in training phase is longer than which
@@ -107,16 +112,19 @@ class RecurrentCritic(nn.Module):
 :ref:`build_the_network`.
 """
 
-def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
+def __init__(self, layer_num, state_shape,
+action_shape=0, device='cpu', hidden_layer_size=128):
 super().__init__()
 self.state_shape = state_shape
 self.action_shape = action_shape
 self.device = device
-self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
+self.nn = nn.LSTM(input_size=np.prod(state_shape),
+hidden_size=hidden_layer_size,
 num_layers=layer_num, batch_first=True)
-self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
+self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1)
 
 def forward(self, s, a=None):
+"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
 s = to_torch(s, device=self.device, dtype=torch.float32)
 # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
 # In short, the tensor's shape in training phase is longer than which
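The continuous-control heads gain the same `hidden_layer_size` keyword, which has to match the width of the preprocessing `Net` they sit on top of. A hedged sketch for a Pendulum-style setup; the shapes and the `concat=True` critic preprocessing follow the signatures shown in this diff, while the environment choice and the forward calls at the end are assumptions:

```python
import gym
import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic

env = gym.make('Pendulum-v0')
state_shape = env.observation_space.shape
action_shape = env.action_space.shape
max_action = env.action_space.high[0]
hidden = 256  # must be consistent between the preprocess net and its head

# actor: preprocess net maps obs -> 256-dim features, head maps features to actions
actor_pre = Net(layer_num=1, state_shape=state_shape, hidden_layer_size=hidden)
actor = Actor(actor_pre, action_shape, max_action, hidden_layer_size=hidden)

# critic: concat=True makes the preprocess net take (obs, act) pairs as input
critic_pre = Net(layer_num=1, state_shape=state_shape, action_shape=action_shape,
                 concat=True, hidden_layer_size=hidden)
critic = Critic(critic_pre, hidden_layer_size=hidden)

obs = torch.rand(4, *state_shape)
act = torch.rand(4, *action_shape)
print(actor(obs)[0].shape)     # (4, 1) batch of actions
print(critic(obs, act).shape)  # (4, 1) batch of Q-values
```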
@@ -9,12 +9,13 @@ class Actor(nn.Module):
 :ref:`build_the_network`.
 """
 
-def __init__(self, preprocess_net, action_shape):
+def __init__(self, preprocess_net, action_shape, hidden_layer_size=128):
 super().__init__()
 self.preprocess = preprocess_net
-self.last = nn.Linear(128, np.prod(action_shape))
+self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
 
 def forward(self, s, state=None, info={}):
+r"""s -> Q(s, \*)"""
 logits, h = self.preprocess(s, state)
 logits = F.softmax(self.last(logits), dim=-1)
 return logits, h
@@ -25,12 +26,13 @@ class Critic(nn.Module):
 :ref:`build_the_network`.
 """
 
-def __init__(self, preprocess_net):
+def __init__(self, preprocess_net, hidden_layer_size=128):
 super().__init__()
 self.preprocess = preprocess_net
-self.last = nn.Linear(128, 1)
+self.last = nn.Linear(hidden_layer_size, 1)
 
 def forward(self, s, **kwargs):
+"""s -> V(s)"""
 logits, h = self.preprocess(s, state=kwargs.get('state', None))
 logits = self.last(logits)
 return logits
@@ -62,6 +64,7 @@ class DQN(nn.Module):
 self.head = nn.Linear(512, action_shape)
 
 def forward(self, x, state=None, info={}):
+r"""x -> Q(x, \*)"""
 if not isinstance(x, torch.Tensor):
 x = torch.tensor(x, device=self.device, dtype=torch.float32)
 x = x.permute(0, 3, 1, 2)
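The discrete heads follow the same pattern: an `Actor` and a `Critic` can share one preprocessing `Net`, and the new `hidden_layer_size` keyword must agree with that net's width. A hedged sketch; sharing a single `Net` between the two heads is an illustrative A2C/PPO-style choice, not something this diff mandates:

```python
import gym
import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
hidden = 64

# one shared feature extractor, two heads
feature_net = Net(layer_num=1, state_shape=state_shape, hidden_layer_size=hidden)
actor = Actor(feature_net, action_shape, hidden_layer_size=hidden)
critic = Critic(feature_net, hidden_layer_size=hidden)

obs = torch.rand(4, *env.observation_space.shape)
probs, _ = actor(obs)   # softmax action probabilities, shape (4, 2)
values = critic(obs)    # state values, shape (4, 1)
print(probs.shape, values.shape)
```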