Trinkle23897 2020-06-01 08:30:09 +08:00
parent 1fce527c77
commit ba1b3e54eb
13 changed files with 84 additions and 73 deletions

View File

@@ -17,13 +17,14 @@ jobs:
       uses: actions/setup-python@v1
       with:
         python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install ".[dev]"
     - name: Lint with flake8
       run: |
-        python -m pip install --upgrade pip
-        pip install flake8
         flake8 . --count --show-source --statistics
-    - name: Install dependencies
-      run: |
-        pip install ".[dev]"
     - name: Test with pytest
       run: |
         pytest test --cov tianshou --cov-report=xml --durations 0 -v

CONTRIBUTING.rst Symbolic link
View File

@@ -0,0 +1 @@
+docs/contributing.rst

View File

@@ -1,14 +1,15 @@
+============
 Contributing
 ============
 Install Develop Version
------------------------
+=======================
 To install Tianshou in an "editable" mode, run
 .. code-block:: bash
-    pip3 install -e .
+    pip3 install -e ".[dev]"
 in the main directory. This installation is removable by
@@ -16,14 +17,8 @@ in the main directory. This installation is removable by
     python3 setup.py develop --uninstall
-Additional dependencies for developments can be installed by
-.. code-block:: bash
-    pip3 install ".[dev]"
 PEP8 Code Style Check
----------------------
+=====================
 We follow PEP8 python code style. To check, in the main directory, run:
@@ -32,7 +27,7 @@ We follow PEP8 python code style. To check, in the main directory, run:
     flake8 . --count --show-source --statistics
 Test Locally
-------------
+============
 This command will run automatic tests in the main directory
@@ -41,7 +36,7 @@ This command will run automatic tests in the main directory
     pytest test --cov tianshou -s --durations 0 -v
 Test by GitHub Actions
-----------------------
+======================
 1. Click the ``Actions`` button in your own repo:
@@ -61,7 +56,7 @@ Test by GitHub Actions
     :align: center
 Documentation
--------------
+=============
 Documentations are written under the ``docs/`` directory as ReStructuredText (``.rst``) files. ``index.rst`` is the main page. A Tutorial on ReStructuredText can be found `here <https://pythonhosted.org/an_example_pypi_project/sphinx.html>`_.
@@ -74,3 +69,8 @@ To compile documentation into webpages, run
     make html
 under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
+Chinese Documentation
+=====================
+Chinese documentation is in https://github.com/thu-ml/tianshou-docs-zh_CN.

View File

@@ -1,3 +1,4 @@
+===========
 Contributor
 ===========

View File

@@ -3,6 +3,7 @@
    You can adapt this file completely to your liking, but it should at least
    contain the root `toctree` directive.
+====================
 Welcome to Tianshou!
 ====================
@@ -25,9 +26,9 @@ Tianshou supports parallel workers for all algorithms as well. All of these algo
 Installation
-------------
+============
-Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
+Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command (with Python >= 3.6):
 ::
     pip3 install tianshou
@@ -86,7 +87,7 @@ Tianshou is still under development, you can also check out the documents in sta
 Indices and tables
-------------------
+==================
 * :ref:`genindex`
 * :ref:`modindex`

View File

@@ -1,3 +1,4 @@
+==========================
 Basic concepts in Tianshou
 ==========================
@@ -9,7 +10,7 @@ Tianshou splits a Reinforcement Learning agent training procedure into these par
 Data Batch
-----------
+==========
 .. automodule:: tianshou.data.Batch
    :members:
@@ -17,7 +18,7 @@ Data Batch
 Data Buffer
------------
+===========
 .. automodule:: tianshou.data.ReplayBuffer
    :members:
@@ -28,7 +29,7 @@ Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListR
 .. _policy_concept:
 Policy
-------
+======
 Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.
@@ -90,7 +91,7 @@ For other method, you can check out :doc:`/api/tianshou.policy`. We give the usa
 Collector
---------- 
+=========
 The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
 In short, :class:`~tianshou.data.Collector` has two main methods:
@@ -106,7 +107,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collecto
 Trainer
--------
+=======
 Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.
@@ -118,7 +119,7 @@ There will be more types of trainers, for instance, multi-agent trainer.
 .. _pseudocode:
 A High-level Explanation
-------------------------
+========================
 We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
 ::
@@ -141,6 +142,6 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
 Conclusion
-----------
+==========
 So far, we go through the overall framework of Tianshou. Really simple, isn't it?

View File

@@ -1,3 +1,4 @@
+==============
 Deep Q Network
 ==============
@@ -10,7 +11,7 @@ Contrary to existing Deep RL libraries such as `RLlib <https://github.com/ray-pr
 Make an Environment
--------------------
+===================
 First of all, you have to make an environment for your agent to interact with. For environment interfaces, we follow the convention of `OpenAI Gym <https://github.com/openai/gym>`_. In your Python code, simply import Tianshou and make the environment:
 ::
@@ -24,7 +25,7 @@ CartPole-v0 is a simple environment with a discrete action space, for which DQN
 Setup Multi-environment Wrapper
--------------------------------
+===============================
 It is available if you want the original ``gym.Env``:
 ::
@@ -44,7 +45,7 @@ For the demonstration, here we use the second block of codes.
 Build the Network
------------------
+=================
 Tianshou supports any user-defined PyTorch networks and optimizers but with the limitation of input and output API. Here is an example code:
 ::
@@ -80,7 +81,7 @@ The rules of self-defined networks are:
 Setup Policy
-------------
+============
 We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with using a target network:
 ::
@@ -91,7 +92,7 @@ We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to
 Setup Collector
----------------
+===============
 The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently.
 In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.
@@ -102,7 +103,7 @@ In each step, the collector will let the policy perform (at least) a specified n
 Train Policy with a Trainer
---------------------------- 
+===========================
 Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows:
 ::
@@ -158,7 +159,7 @@ It shows that within approximately 4 seconds, we finished training a DQN agent o
 Save/Load Policy
-----------------
+================
 Since the policy inherits the ``torch.nn.Module`` class, saving and loading the policy are exactly the same as a torch module:
 ::
@@ -168,7 +169,7 @@ Since the policy inherits the ``torch.nn.Module`` class, saving and loading the
 Watch the Agent's Performance
------------------------------
+=============================
 :class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent's performance in 35 FPS:
 ::
@@ -181,7 +182,7 @@ Watch the Agent's Performance
 .. _customized_trainer:
 Train a Policy with Customized Codes
------------------------------------- 
+====================================
 "I don't want to use your provided trainer. I want to customize it!"

View File

@@ -1,3 +1,4 @@
+======================================
 Train a model-free RL agent within 30s
 ======================================
@@ -7,7 +8,7 @@ You can also contribute to this page with your own tricks :)
 Avoid batch-size = 1
--------------------- 
+====================
 In the traditional RL training loop, we always use the policy to interact with only one environment for collecting data. That means most of the time the network use batch-size = 1. Quite inefficient!
 Here is an example of showing how inefficient it is:
@@ -52,7 +53,7 @@ By the way, A2C is better than A3C in some cases: A3C needs to act independently
 Algorithm specific tricks
-------------------------- 
+=========================
 Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
@@ -66,7 +67,7 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
 Code-level optimization
----------------------- -
+=======================
 Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
@@ -74,7 +75,7 @@ Tianshou has many short-but-efficient lines of code. For example, when we want t
 Finally
-------- 
+=======
 With fast-speed sampling, we could use large batch-size and large learning rate for faster convergence.

View File

@@ -28,7 +28,7 @@ class MyTestEnv(gym.Env):
         if action == 0:
             self.index = max(self.index - 1, 0)
             if self.dict_state:
-                return {'index': self.index}, 0, False, {}
+                return {'index': self.index}, 0, False, {'key': 1, 'env': self}
             else:
                 return self.index, 0, False, {}
         elif action == 1:
@@ -36,6 +36,7 @@ class MyTestEnv(gym.Env):
             self.done = self.index == self.size
             if self.dict_state:
                 return {'index': self.index}, int(self.done), self.done, \
-                    {'key': 1}
+                    {'key': 1, 'env': self}
             else:
-                return self.index, int(self.done), self.done, {'key': 1}
+                return self.index, int(self.done), self.done, \
+                    {'key': 1, 'env': self}

View File

@@ -1,6 +1,6 @@
+import pytest
+import pickle
 import torch
-import pickle
-import pytest
 import numpy as np
 from tianshou.data import Batch, to_torch
@@ -15,11 +15,15 @@ def test_batch():
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
     assert batch[0].obs == batch[1].obs
+    with pytest.raises(IndexError):
+        batch[2]
     batch.obs = np.arange(5)
     for i, b in enumerate(batch.split(1, shuffle=False)):
-        assert b.obs == batch[i].obs
+        if i != 5:
+            assert b.obs == batch[i].obs
+        else:
+            with pytest.raises(AttributeError):
+                batch[i].obs
+            with pytest.raises(AttributeError):
+                b.obs
     print(batch)
@@ -95,3 +99,7 @@ def test_batch_from_to_numpy_without_copy():
 if __name__ == '__main__':
     test_batch()
     test_batch_over_batch()
+    test_batch_over_batch_to_torch()
+    test_utils_to_torch()
+    test_batch_pickle()
+    test_batch_from_to_numpy_without_copy()

View File

@@ -39,7 +39,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
             A = v.reset(np.where(C)[0])
             o.append([A, B, C, D])
         for i in zip(*o):
-            for j in range(1, len(i)):
+            for j in range(1, len(i) - 1):
                 assert (i[0] == i[j]).all()
     else:
         t = [0, 0, 0]

View File

@@ -122,8 +122,11 @@ class Batch:
             return self.__getattr__(index)
         b = Batch()
         for k, v in self.__dict__.items():
-            if k != '_meta' and v is not None:
-                b.__dict__.update(**{k: v[index]})
+            if k != '_meta' and hasattr(v, '__len__'):
+                try:
+                    b.__dict__.update(**{k: v[index]})
+                except IndexError:
+                    continue
         b._meta = self._meta
         return b
@@ -238,8 +241,8 @@
     def __len__(self) -> int:
         """Return len(self)."""
-        return min([len(v) for k, v in self.__dict__.items()
-                    if k != '_meta' and v is not None])
+        r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]
+        return max(r) if len(r) > 0 else 0
     def split(self, size: Optional[int] = None,
               shuffle: bool = True) -> Iterator['Batch']:
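In effect, indexing a Batch now skips keys that are too short for the requested index instead of failing, and len() reports the longest key rather than the shortest. A minimal sketch of that behavior (illustrative only, not code from this commit; it assumes the Batch constructor simply stores the given keyword arrays as attributes):

import numpy as np
from tianshou.data import Batch

# Sketch of the behavior implied by the hunks above (not part of the commit):
# keys of a Batch may now have different lengths.
b = Batch(obs=np.arange(5), vec=np.zeros((3, 4)))
print(len(b))    # 5 -- __len__ now takes the max over keys, not the min
item = b[4]      # 'vec' has only 3 rows, so its IndexError is swallowed
print(item.obs)  # 4
# item.vec       # would raise AttributeError, since the key was skipped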

View File

@@ -1,7 +1,6 @@
 import pprint
 import numpy as np
-from copy import deepcopy
-from typing import Tuple, Union, Optional
+from typing import Any, Tuple, Union, Optional
 from tianshou.data.batch import Batch
@@ -60,7 +59,7 @@
         ReplayBuffer(
             act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
             done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
-            info: array([{}, {}, {}, {}, {}, {}, {}, {}, {}], dtype=object),
+            info: Batch(),
             obs: Batch(
                 id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
             ),
@@ -137,9 +136,7 @@
                 d[k_] = self.__dict__[k__]
         return Batch(**d)
-    def _add_to_buffer(
-            self, name: str,
-            inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
+    def _add_to_buffer(self, name: str, inst: Any) -> None:
         if inst is None:
             if getattr(self, name, None) is None:
                 self.__dict__[name] = None
@@ -153,18 +150,17 @@
             self.__dict__[name] = np.zeros(
                 (self._maxsize, *inst.shape), dtype=inst.dtype)
         elif isinstance(inst, (dict, Batch)):
-            if name == 'info':
-                self.__dict__[name] = np.array(
-                    [{} for _ in range(self._maxsize)])
-            else:
-                if self._meta.get(name, None) is None:
-                    self._meta[name] = list(inst.keys())
-                for k in inst.keys():
-                    k_ = '_' + name + '@' + k
-                    self._add_to_buffer(k_, inst[k])
-        else:  # assume `inst` is a number
+            if self._meta.get(name, None) is None:
+                self._meta[name] = list(inst.keys())
+            for k in inst.keys():
+                k_ = '_' + name + '@' + k
+                self._add_to_buffer(k_, inst[k])
+        elif np.isscalar(inst):
             self.__dict__[name] = np.zeros(
                 (self._maxsize,), dtype=np.asarray(inst).dtype)
+        else:  # fall back to np.object
+            self.__dict__[name] = np.array(
+                [None for _ in range(self._maxsize)])
         if isinstance(inst, np.ndarray) and \
                 self.__dict__[name].shape[1:] != inst.shape:
             raise ValueError(
@@ -172,8 +168,6 @@
                 f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
                 f"given shape: {inst.shape}.")
         if name not in self._meta:
-            if name == 'info':
-                inst = deepcopy(inst)
             self.__dict__[name][self._index] = inst
@@ -198,7 +192,7 @@
             policy: Optional[Union[dict, Batch]] = {},
             **kwargs) -> None:
         """Add a batch of data into replay buffer."""
-        assert isinstance(info, dict), \
+        assert isinstance(info, (dict, Batch)), \
             'You should return a dict in the last argument of env.step().'
         self._add_to_buffer('obs', obs)
         self._add_to_buffer('act', act)
@@ -330,8 +324,6 @@ class ListReplayBuffer(ReplayBuffer):
             return
         if self.__dict__.get(name, None) is None:
            self.__dict__[name] = []
-        if name == 'info':
-            inst = deepcopy(inst)
         self.__dict__[name].append(inst)
     def reset(self) -> None:
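With this change, a dict-valued `info` returned by `env.step()` is no longer deep-copied into an object array of per-step dicts; each key is registered in `_meta` and stored in its own preallocated array under an internal `_<name>@<key>` slot, so the buffer exposes `info` as a Batch (as the updated docstring shows). A rough sketch of the resulting storage (illustrative only, not code from this commit):

from tianshou.data import ReplayBuffer

# Sketch of the reworked info handling implied by the hunks above
# (assumption-based example, not code from this commit): dict-valued
# info is split per key and kept in preallocated per-key arrays.
buf = ReplayBuffer(size=4)
for i in range(3):
    buf.add(obs=i, act=i, rew=0.0, done=False, info={'key': i})

print(buf._meta['info'])          # ['key'] -- keys registered for 'info'
print(buf.__dict__['_info@key'])  # [0 1 2 0] -- per-key preallocated storage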