diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml
index 90c6caf..07e1def 100644
--- a/.github/workflows/lint_and_docs.yml
+++ b/.github/workflows/lint_and_docs.yml
@@ -16,15 +16,14 @@ jobs:
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
- python -m pip install flake8
+ python -m pip install ".[dev]" --upgrade
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- - name: Install dependencies
- run: |
- pip install ".[dev]" --upgrade
- name: Documentation test
run: |
+ pydocstyle tianshou
+ doc8 docs --max-line-length 1000
cd docs
make html SPHINXOPTS="-W"
cd ..
diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml
index 3bdd8ea..5b0f358 100644
--- a/.github/workflows/profile.yml
+++ b/.github/workflows/profile.yml
@@ -16,7 +16,7 @@ jobs:
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
- pip install ".[dev]" --upgrade
+ python -m pip install ".[dev]" --upgrade
- name: Test with pytest
run: |
pytest test/throughput --durations=0 -v
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 84ab74f..1b42770 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -26,7 +26,7 @@ jobs:
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
- pip install ".[dev]" --upgrade
+ python -m pip install ".[dev]" --upgrade
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
diff --git a/README.md b/README.md
index eebbfde..f65837e 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,7 @@ Here is Tianshou's other features:
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
+- Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions)
In Chinese, Tianshou means divinely ordained, i.e., a gift bestowed at birth. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by oneself through constant interaction with the environment.
diff --git a/docs/conf.py b/docs/conf.py
index eb6f65f..2169ab8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -23,9 +23,9 @@ version = tianshou.__version__
# -- Project information -----------------------------------------------------
-project = 'Tianshou'
-copyright = '2020, Tianshou contributors.'
-author = 'Tianshou contributors'
+project = "Tianshou"
+copyright = "2020, Tianshou contributors."
+author = "Tianshou contributors"
# The full version, including alpha/beta/rc tags
release = version
@@ -37,51 +37,61 @@ release = version
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.doctest',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.coverage',
+ "sphinx.ext.autodoc",
+ "sphinx.ext.doctest",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.coverage",
# 'sphinx.ext.imgmath',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.ifconfig',
- 'sphinx.ext.viewcode',
- 'sphinx.ext.githubpages',
- 'sphinxcontrib.bibtex',
+ "sphinx.ext.mathjax",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.githubpages",
+ "sphinxcontrib.bibtex",
]
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-source_suffix = ['.rst', '.md']
-master_doc = 'index'
+templates_path = ["_templates"]
+source_suffix = [".rst", ".md"]
+master_doc = "index"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
-autodoc_default_options = {'special-members': ', '.join([
- '__len__', '__call__', '__getitem__', '__setitem__',
- '__getattr__', '__setattr__'])}
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+autodoc_default_options = {
+ "special-members": ", ".join(
+ [
+ "__len__",
+ "__call__",
+ "__getitem__",
+ "__setitem__",
+ # "__getattr__",
+ # "__setattr__",
+ ]
+ )
+}
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
-html_logo = '_static/images/tianshou-logo.png'
+html_logo = "_static/images/tianshou-logo.png"
def setup(app):
app.add_js_file("js/copybutton.js")
app.add_css_file("css/style.css")
+
# -- Extension configuration -------------------------------------------------
# -- Options for intersphinx extension ---------------------------------------
diff --git a/docs/contributing.rst b/docs/contributing.rst
index 063db78..b56a7a4 100644
--- a/docs/contributing.rst
+++ b/docs/contributing.rst
@@ -1,6 +1,7 @@
Contributing to Tianshou
========================
+
Install Develop Version
-----------------------
@@ -16,6 +17,7 @@ in the main directory. This installation is removable by
$ python setup.py develop --uninstall
+
PEP8 Code Style Check
---------------------
@@ -25,6 +27,7 @@ We follow PEP8 python code style. To check, in the main directory, run:
$ flake8 . --count --show-source --statistics
+
Test Locally
------------
@@ -34,6 +37,7 @@ This command will run automatic tests in the main directory
$ pytest test --cov tianshou -s --durations 0 -v
+
Test by GitHub Actions
----------------------
@@ -54,6 +58,7 @@ Test by GitHub Actions
.. image:: _static/images/action3.png
:align: center
+
Documentation
-------------
@@ -70,3 +75,28 @@ To compile documentation into webpages, run
under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
+
+
+Documentation Generation Test
+-----------------------------
+
+We have the following three documentation tests:
+
+1. pydocstyle: test the docstring style under ``tianshou/``. To check, in the main directory, run:
+
+.. code-block:: bash
+
+ $ pydocstyle tianshou
+
+2. doc8: test the reStructuredText format under ``docs/``. To check, in the main directory, run:
+
+.. code-block:: bash
+
+ $ doc8 docs
+
+3. sphinx test: test whether there are any errors or warnings when generating the front-end HTML documentation. To check, in the main directory, run:
+
+.. code-block:: bash
+
+ $ cd docs
+ $ make html SPHINXOPTS="-W"
diff --git a/docs/index.rst b/docs/index.rst
index bfb2ddf..5d962af 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,6 +3,7 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
+
Welcome to Tianshou!
====================
@@ -25,14 +26,15 @@ Here is Tianshou's other features:
* Elegant framework, using only ~2000 lines of code
* Support parallel environment simulation (synchronous or asynchronous) for all algorithms: :ref:`parallel_sampling`
-* Support recurrent state/action representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
-* Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
-* Support customized training process: :ref:`customize_training`
+* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
+* Support any type of environment state/action (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
+* Support :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
-* Support multi-agent RL: :doc:`/tutorials/tictactoe`
+* Support :doc:`/tutorials/tictactoe`
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ `_
+
Installation
------------
@@ -70,6 +72,7 @@ If no error occurs, you have successfully installed Tianshou.
Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ `_.
+
.. toctree::
:maxdepth: 1
:caption: Tutorials
@@ -81,6 +84,7 @@ Tianshou is still under development, you can also check out the documents in sta
tutorials/trick
tutorials/cheatsheet
+
.. toctree::
:maxdepth: 1
:caption: API Docs
@@ -92,6 +96,7 @@ Tianshou is still under development, you can also check out the documents in sta
api/tianshou.exploration
api/tianshou.utils
+
.. toctree::
:maxdepth: 1
:caption: Community
diff --git a/docs/tutorials/batch.rst b/docs/tutorials/batch.rst
index 390fc41..49d913d 100644
--- a/docs/tutorials/batch.rst
+++ b/docs/tutorials/batch.rst
@@ -3,9 +3,10 @@
Understand Batch
================
-:class:`~tianshou.data.Batch` is the internal data structure extensively used in Tianshou. It is designed to store and manipulate hierarchical named tensors. This tutorial aims to help users correctly understand the concept and the behavior of ``Batch`` so that users can make the best of Tianshou.
+:class:`~tianshou.data.Batch` is the internal data structure extensively used in Tianshou. It is designed to store and manipulate hierarchical named tensors. This tutorial aims to help users correctly understand the concept and the behavior of :class:`~tianshou.data.Batch` so that users can make the best of Tianshou.
+
+The tutorial has three parts. We first explain the concept of hierarchical named tensors, then introduce the basic usage of :class:`~tianshou.data.Batch`, and finally cover advanced topics of :class:`~tianshou.data.Batch`.
-The tutorial has three parts. We first explain the concept of hierarchical named tensors, and introduce basic usage of ``Batch``, followed by advanced topics of ``Batch``.
Hierarchical Named Tensors
---------------------------
@@ -43,11 +44,13 @@ Note that, storing hierarchical named tensors is as easy as creating nested dict
The real problem is how to **manipulate them**, such as adding new transition tuples into replay buffer and dealing with their heterogeneity. ``Batch`` is designed to easily create, store, and manipulate these hierarchical named tensors.
+
Basic Usages
------------
Here we cover some basic usages of ``Batch``, describing what ``Batch`` contains, how to construct ``Batch`` objects and how to manipulate them.
+
What Does Batch Contain
^^^^^^^^^^^^^^^^^^^^^^^
@@ -69,6 +72,7 @@ The content of ``Batch`` objects can be defined by the following rules.
The data types of tensors are bool and numbers (any size of int and float as long as they are supported by NumPy or PyTorch). Besides, NumPy supports ndarray of objects and we take advantage of this feature to store non-number objects in ``Batch``. If one wants to store data that are neither boolean nor numbers (such as strings and sets), they can store the data in ``np.ndarray`` with the ``np.object`` data type. This way, ``Batch`` can store any type of python objects.
+
Construction of Batch
^^^^^^^^^^^^^^^^^^^^^
@@ -136,6 +140,7 @@ There are two ways to construct a ``Batch`` object: from a ``dict``, or using ``
+
Data Manipulation With Batch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -285,11 +290,13 @@ Stacking and concatenating multiple ``Batch`` instances, or split an instance in
+
Advanced Topics
---------------
From here on, this tutorial focuses on advanced topics of ``Batch``, including key reservation, length/shape, and aggregation of heterogeneous batches.
+
.. _key_reservations:
Key Reservations
@@ -347,6 +354,7 @@ The ``Batch.is_empty`` function has an option to decide whether to identify dire
Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
+
Length and Shape
^^^^^^^^^^^^^^^^
@@ -391,6 +399,7 @@ The ``obj.shape`` attribute of ``Batch`` behaves somewhat similar to ``len(obj)`
4. The shape of reserved keys is undetermined, too. We treat their shape as ``[]``.
+
.. _aggregation:
Aggregation of Heterogeneous Batches
@@ -457,6 +466,7 @@ For a set of ``Batch`` objects denoted as :math:`S`, they can be aggregated if t
The ``Batch`` object ``b`` satisfying these rules with the minimum number of keys determines the structure of aggregating :math:`S`. The values are relatively easy to define: for any key chain ``k`` that applies to ``b``, ``b[k]`` is the stack/concatenation of ``[bi[k] for bi in S]`` (if ``k`` does not apply to ``bi``, the appropriate size of zeros or ``None`` are filled automatically). If ``bi[k]`` are all ``Batch()``, then the aggregation result is also an empty ``Batch()``.
+
Miscellaneous Notes
^^^^^^^^^^^^^^^^^^^
diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
index 3235843..c088a8d 100644
--- a/docs/tutorials/cheatsheet.rst
+++ b/docs/tutorials/cheatsheet.rst
@@ -5,6 +5,7 @@ This page shows some code snippets of how to use Tianshou to develop new algorit
By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
+
.. _network_api:
Build Policy Network
@@ -12,6 +13,7 @@ Build Policy Network
See :ref:`build_the_network`.
+
.. _new_policy:
Build New Policy
@@ -19,6 +21,7 @@ Build New Policy
See :class:`~tianshou.policy.BasePolicy`.
+
.. _customize_training:
Customize Training Process
@@ -26,6 +29,7 @@ Customize Training Process
See :ref:`customized_trainer`.
+
.. _parallel_sampling:
Parallel Sampling
@@ -66,7 +70,7 @@ Asynchronous simulation is a built-in functionality of :class:`~tianshou.env.Bas
# DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like.
venv = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2)
venv.reset() # returns the initial observations of each environment
- # returns ``wait_num`` steps or finished steps after ``timeout`` seconds,
+ # returns "wait_num" steps or finished steps after "timeout" seconds,
# whichever occurs first.
venv.step(actions, ready_id)
@@ -87,6 +91,7 @@ The figure in the right gives an intuitive comparison among synchronous/asynchro
Otherwise, the outputs of these envs may be the same with each other.
+
.. _preprocess_fn:
Handle Batched Data Stream in Collector
@@ -96,17 +101,24 @@ This is related to `Issue 42 `_.
If you want to get log stat from the data stream / pre-process batch-images / modify the reward with given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data is added into the buffer.
-This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env reset, while every key is specified for normal steps. For example, you can write your hook as:
+This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`. It returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps.
+
+These variables are intended to gather all the information required to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
+
+For example, you can write your hook as:
::
import numpy as np
from collections import deque
+
+
class MyProcessor:
def __init__(self, size=100):
self.episode_log = None
self.main_log = deque(maxlen=size)
self.main_log.append(0)
self.baseline = 0
+
def preprocess_fn(self, **kwargs):
"""change reward to zero mean"""
# if only obs exist -> reset
@@ -136,6 +148,7 @@ And finally,
Some examples are in `test/base/test_collector.py `_.
+
.. _rnn_training:
RNN-style Training
@@ -143,18 +156,19 @@ RNN-style Training
This is related to `Issue 19 `_.
-First, add an argument ``stack_num`` to :class:`~tianshou.data.ReplayBuffer`:
+First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`:
::
buf = ReplayBuffer(size=size, stack_num=stack_num)
-Then, change the network to recurrent-style, for example, class ``Recurrent`` in `code snippet 1 `_, or ``RecurrentActor`` and ``RecurrentCritic`` in `code snippet 2 `_.
+Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`.
The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.wrapper`` to modify the state representation. For example, if we add a wrapper that maps the [s, a] pair to a new state (a sketch of such a wrapper follows the list below):
- Before: (s, a, s', r, d) stored in replay buffer, and get stacked s;
- After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a.
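+
+A minimal sketch of such a wrapper (illustrative only; it assumes a flat ``Box`` observation space and a continuous action space, and the class name is hypothetical):
+::
+
+    import gym
+    import numpy as np
+
+
+    class StateActionWrapper(gym.Wrapper):
+        """Append the last action to the observation."""
+
+        def reset(self):
+            self.last_act = np.zeros(self.env.action_space.shape)
+            obs = self.env.reset()
+            return np.concatenate([obs, self.last_act])
+
+        def step(self, act):
+            obs, rew, done, info = self.env.step(act)
+            self.last_act = np.asarray(act, dtype=np.float64).ravel()
+            return np.concatenate([obs, self.last_act]), rew, done, info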
+
.. _self_defined_env:
User-defined Environment and Different State Representation
@@ -174,11 +188,11 @@ First of all, your self-defined environment must follow the Gym's API, some of t
- close() -> None
-- observation_space
+- observation_space: gym.Space
-- action_space
+- action_space: gym.Space
-The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example:
+The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1" as an example:
::
>>> e = gym.make('FetchReach-v1')
@@ -285,6 +299,7 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y
...
return copy.deepcopy(self.graph), reward, done, {}
+
.. _marl_example:
Multi-Agent Reinforcement Learning
diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index ba771ad..d7f5971 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -14,6 +14,7 @@ Here is a more detailed description, where ``Env`` is the environment and ``Mode
:align: center
:height: 300
+
Batch
-----
@@ -48,6 +49,7 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
:ref:`batch_concept` is a dedicated tutorial for :class:`~tianshou.data.Batch`. We strongly recommend every user to read it so as to correctly understand and use :class:`~tianshou.data.Batch`.
+
Buffer
------
@@ -57,7 +59,6 @@ Buffer
Tianshou provides other types of data buffers such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
-.. _policy_concept:
Policy
------
@@ -73,6 +74,36 @@ A policy class typically has the following parts:
* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data.
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.
+
+policy.forward
+^^^^^^^^^^^^^^
+
+The ``forward`` function computes the action over given observations. The input and output are algorithm-specific, but generally the function is a mapping of ``(batch, state, ...) -> batch``.
+
+The input batch is the environment data (e.g., observation, reward, done flag and info). It comes from either :meth:`~tianshou.data.Collector.collect` or :meth:`~tianshou.data.ReplayBuffer.sample`. The first dimension of all variables in the input ``batch`` should be equal to the batch-size.
+
+The output is also a Batch which must contain "act" (action) and may contain "state" (hidden state of policy), "policy" (the intermediate result of the policy which needs to be saved into the buffer, see :meth:`~tianshou.policy.BasePolicy.forward`), and some other algorithm-specific keys.
+
+For example, if you try to use your policy to evaluate one episode (and don't want to use :meth:`~tianshou.data.Collector.collect`), use the following code-snippet:
+::
+
+ # assume env is a gym.Env
+ obs, done = env.reset(), False
+ while not done:
+ batch = Batch(obs=[obs]) # the first dimension is batch-size
+ act = policy(batch).act[0]  # policy.forward returns a Batch; use ".act" to extract the action
+ obs, rew, done, info = env.step(act)
+
+Here, ``Batch(obs=[obs])`` will automatically create the 0th dimension as the batch-size (1 in this case). Otherwise, the network cannot determine the batch-size.
+
+
+.. _process_fn:
+
+policy.process_fn
+^^^^^^^^^^^^^^^^^
+
+The ``process_fn`` function computes some variables that depend on the time series, for example, the n-step or GAE returns.
+
Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as:
.. math::
@@ -128,11 +159,11 @@ Collector
The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
-:class:`~tianshou.data.Collector` has one main method :meth:`~tianshou.data.Collector.collect`: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer.
+:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it lets the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and stores the data in the replay buffer.
Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
-The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
+The proposed solution is to add some cache buffers inside the collector. Once **a full episode of trajectory** is collected, the stored data is moved from the cache buffer to the main buffer. To satisfy this condition, the collector may interact with the environments for more than the given number of steps or episodes.
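+
+A minimal usage sketch (mirroring the DQN tutorial; ``policy`` and the vectorized ``train_envs`` are assumed to be already created):
+::
+
+    from tianshou.data import Collector, ReplayBuffer
+
+    collector = Collector(policy, train_envs, ReplayBuffer(size=20000))
+    result = collector.collect(n_step=1000)  # collect at least 1000 transitions
+    print(result)  # statistics of the collected data, e.g. mean episode reward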
The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation.
@@ -150,7 +181,7 @@ Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` an
A High-level Explanation
------------------------
-We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
+We give a high-level explanation through the pseudocode used in section :ref:`process_fn`:
::
# pseudocode, cannot work # methods in tianshou
@@ -158,13 +189,13 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # policy.__init__(...)
for i in range(int(1e6)): # done in trainer
- a = agent.compute_action(s) # policy(batch, ...)
+ a = agent.compute_action(s) # act = policy(batch, ...).act
s_, r, d, _ = env.step(a) # collector.collect(...)
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
if i % 1000 == 0: # done in trainer
# the following is done in policy.update(batch_size, buffer)
- b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # buffer.sample(batch_size)
+ b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indice = buffer.sample(batch_size)
# compute 2-step returns. How?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice)
# update DQN policy
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index 9655ee8..d923b56 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -8,6 +8,7 @@ The full script is at `test/discrete/test_dqn.py `_, which could only accept a config specification of hyperparameters, network, and others, Tianshou provides an easy way of construction through the code-level.
+
Make an Environment
-------------------
@@ -21,10 +22,11 @@ First of all, you have to make an environment for your agent to interact with. F
CartPole-v0 is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both, depending on the probability distribution on the action.
+
Setup Multi-environment Wrapper
-------------------------------
-It is available if you want the original ``gym.Env``:
+If you want to use the original ``gym.Env``:
::
train_envs = gym.make('CartPole-v0')
@@ -38,7 +40,7 @@ Tianshou supports parallel sampling for all algorithms. It provides four types o
Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.
-For the demonstration, here we use the second block of codes.
+For the demonstration, here we use the second code block.
.. warning::
@@ -51,12 +53,13 @@ For the demonstration, here we use the second block of codes.
Otherwise, the outputs of these envs may be the same with each other.
+
.. _build_the_network:
Build the Network
-----------------
-Tianshou supports any user-defined PyTorch networks and optimizers but with the limitation of input and output API. Here is an example code:
+Tianshou supports any user-defined PyTorch networks and optimizers, as long as their input and output comply with Tianshou's API. Here is an example:
::
import torch, numpy as np
@@ -65,12 +68,13 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
class Net(nn.Module):
def __init__(self, state_shape, action_shape):
super().__init__()
- self.model = nn.Sequential(*[
+ self.model = nn.Sequential(
nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
nn.Linear(128, 128), nn.ReLU(inplace=True),
nn.Linear(128, 128), nn.ReLU(inplace=True),
- nn.Linear(128, np.prod(action_shape))
- ])
+ nn.Linear(128, np.prod(action_shape)),
+ )
+
def forward(self, obs, state=None, info={}):
if not isinstance(obs, torch.Tensor):
obs = torch.tensor(obs, dtype=torch.float)
@@ -83,29 +87,32 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
-You can also have a try with those pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:
+It is also possible to use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:
1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
-2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
+2. Output: some ``logits`` and the next hidden state ``state``. The logits can be a tuple instead of a ``torch.Tensor``, or any other useful variables or results produced during the policy forwarding procedure; it depends on how the policy class processes the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for a Gaussian policy (see the sketch below).
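+
+For instance, a Gaussian actor for PPO could be sketched as follows (illustrative only, reusing the imports above; :class:`~tianshou.utils.net.continuous.ActorProb` already provides this kind of output):
+::
+
+    class GaussianActor(nn.Module):
+        def __init__(self, state_shape, action_shape):
+            super().__init__()
+            self.net = nn.Sequential(
+                nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
+            )
+            self.mu = nn.Linear(128, np.prod(action_shape))
+            self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape)))
+
+        def forward(self, obs, state=None, info={}):
+            if not isinstance(obs, torch.Tensor):
+                obs = torch.tensor(obs, dtype=torch.float)
+            hidden = self.net(obs.view(obs.shape[0], -1))
+            mu = self.mu(hidden)
+            sigma = torch.exp(self.sigma).expand(mu.shape[0], -1)
+            return (mu, sigma), state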
+
Setup Policy
------------
-We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with using a target network:
+We use the ``net`` and ``optim`` defined above, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with a target network:
::
policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)
+
Setup Collector
---------------
-The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently.
+The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently.
In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.
::
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
test_collector = ts.data.Collector(policy, test_envs)
+
Train Policy with a Trainer
---------------------------
@@ -161,15 +168,17 @@ The returned result is a dictionary as follows:
It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03.
+
Save/Load Policy
----------------
-Since the policy inherits the ``torch.nn.Module`` class, saving and loading the policy are exactly the same as a torch module:
+Since the policy inherits from the class ``torch.nn.Module``, saving and loading the policy are exactly the same as for a torch module:
::
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
+
Watch the Agent's Performance
-----------------------------
@@ -181,6 +190,7 @@ Watch the Agent's Performance
collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
+
.. _customized_trainer:
Train a Policy with Customized Codes
diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst
index 6911177..cc4116d 100644
--- a/docs/tutorials/tictactoe.rst
+++ b/docs/tutorials/tictactoe.rst
@@ -6,6 +6,7 @@ In this section, we describe how to use Tianshou to implement multi-agent reinfo
.. image:: ../_static/images/tic-tac-toe.png
:align: center
+
Tic-Tac-Toe Environment
-----------------------
@@ -15,11 +16,11 @@ The scripts are located at ``test/multiagent/``. We have implemented a Tic-Tac-T
>>> from tic_tac_toe_env import TicTacToeEnv # the module tic_tac_toe_env is in test/multiagent/
>>> board_size = 6 # the size of the board
>>> win_size = 4 # how many signs in a row are considered to win
- >>>
+ >>>
>>> # This board has 6 rows and 6 cols (36 places in total)
>>> # Players place 'x' and 'o' in turn on the board
>>> # The player who first gets 4 consecutive 'x's or 'o's wins
- >>>
+ >>>
>>> env = TicTacToeEnv(size=board_size, win_size=win_size)
>>> obs = env.reset()
>>> env.render() # render the empty board
@@ -105,6 +106,7 @@ One worth-noting case is that the game is over when there is only one empty posi
After being familiar with the environment, let's try to play with random agents first!
+
Two Random Agent
----------------
@@ -119,7 +121,7 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
>>> from tianshou.data import Collector
>>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager
>>>
- >>> # agents should be wrapped into one policy,
+ >>> # agents should be wrapped into one policy,
>>> # which is responsible for calling the acting agent correctly
>>> # here we use two random agents
>>> policy = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()])
@@ -159,7 +161,8 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
===x _ _ _ x x===
=================
-Random agents perform badly. In the above game, although agent 2 wins finally, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly.
+Random agents perform badly. In the above game, although agent 2 wins finally, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly.
+
Train an MARL Agent
-------------------
@@ -212,7 +215,7 @@ The explanation of each Tianshou class/function will be deferred to their first
parser.add_argument('--watch', default=False, action='store_true',
help='no training, watch the play of pre-trained models')
parser.add_argument('--agent_id', type=int, default=2,
- help='the learned agent plays as the agent_id-th player. choices are 1 and 2.')
+ help='the learned agent plays as the agent_id-th player. Choices are 1 and 2.')
parser.add_argument('--resume_path', type=str, default='',
help='the path of agent pth file for resuming from a pre-trained agent')
parser.add_argument('--opponent_path', type=str, default='',
@@ -229,7 +232,7 @@ The following ``get_agents`` function returns agents and their optimizers from e
- The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function;
- The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values;
-- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves.
+- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves.
Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible for calling the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment.
diff --git a/docs/tutorials/trick.rst b/docs/tutorials/trick.rst
index 5a73ff9..0de8b41 100644
--- a/docs/tutorials/trick.rst
+++ b/docs/tutorials/trick.rst
@@ -56,13 +56,13 @@ Algorithm specific tricks
Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
-* :class:`~tianshou.policy.DQNPolicy`: use estimation_step greater than 1 and target network, also with a suitable size of replay buffer;
+* :class:`~tianshou.policy.DQNPolicy`: use ``estimation_step`` = 3 or 4 and a target network, together with a suitably sized replay buffer (see the sketch after this list);
* :class:`~tianshou.policy.PGPolicy`: TBD
* :class:`~tianshou.policy.A2CPolicy`: TBD
* :class:`~tianshou.policy.PPOPolicy`: TBD
* :class:`~tianshou.policy.DDPGPolicy`, :class:`~tianshou.policy.TD3Policy`, and :class:`~tianshou.policy.SACPolicy`: We found two tricks. The first is to ignore the done flag. The second is to normalize the reward to a standard normal distribution (it is against the theoretical analysis, but indeed works very well). The two tricks work amazingly on Mujoco tasks, typically with a faster convergence speed (1M -> 200K).
-* On-policy algorithms: increase the repeat-time (to 2 or 4 for trivial benchmark, 10 for mujoco) of the given batch in each training update will make the algorithm more stable.
+* On-policy algorithms: increasing the repeat-time (to 2 or 4 for trivial benchmarks, 10 for mujoco) of the given batch in each training update will make the algorithm more stable.
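+
+For example, a :class:`~tianshou.policy.DQNPolicy` configuration in this spirit (a sketch mirroring the DQN tutorial; ``net`` and ``optim`` are assumed to be defined already):
+::
+
+    import tianshou as ts
+
+    policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9,
+                                 estimation_step=3, target_update_freq=320)
+    buffer = ts.data.ReplayBuffer(size=20000)  # a suitably sized replay buffer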
Code-level optimization
@@ -70,8 +70,6 @@ Code-level optimization
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
-.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
-
Atari/Mujoco Task Specific
--------------------------
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..188700c
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,5 @@
+[pydocstyle]
+ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402
+
+[doc8]
+max-line-length = 1000
diff --git a/setup.py b/setup.py
index 789ea2d..93acc4a 100644
--- a/setup.py
+++ b/setup.py
@@ -12,65 +12,61 @@ def get_version() -> str:
setup(
- name='tianshou',
+ name="tianshou",
version=get_version(),
- description='A Library for Deep Reinforcement Learning',
- long_description=open('README.md', encoding='utf8').read(),
- long_description_content_type='text/markdown',
- url='https://github.com/thu-ml/tianshou',
- author='TSAIL',
- author_email='trinkle23897@gmail.com',
- license='MIT',
- python_requires='>=3.6',
+ description="A Library for Deep Reinforcement Learning",
+ long_description=open("README.md", encoding="utf8").read(),
+ long_description_content_type="text/markdown",
+ url="https://github.com/thu-ml/tianshou",
+ author="TSAIL",
+ author_email="trinkle23897@gmail.com",
+ license="MIT",
+ python_requires=">=3.6",
classifiers=[
# How mature is this project? Common values are
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
- 'Development Status :: 3 - Alpha',
+ "Development Status :: 4 - Beta",
# Indicate who your project is intended for
- 'Intended Audience :: Science/Research',
- 'Topic :: Scientific/Engineering :: Artificial Intelligence',
- 'Topic :: Software Development :: Libraries :: Python Modules',
+ "Intended Audience :: Science/Research",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Software Development :: Libraries :: Python Modules",
# Pick your license as you wish (should match "license" above)
- 'License :: OSI Approved :: MIT License',
+ "License :: OSI Approved :: MIT License",
# Specify the Python versions you support here. In particular, ensure
# that you indicate whether you support Python 2, Python 3 or both.
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7',
- 'Programming Language :: Python :: 3.8',
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
],
- keywords='reinforcement learning platform pytorch',
- packages=find_packages(exclude=['test', 'test.*',
- 'examples', 'examples.*',
- 'docs', 'docs.*']),
+ keywords="reinforcement learning platform pytorch",
+ packages=find_packages(
+ exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"]
+ ),
install_requires=[
- 'gym>=0.15.4',
- 'tqdm',
- 'numpy',
- 'tensorboard',
- 'torch>=1.4.0',
- 'numba>=0.51.0',
+ "gym>=0.15.4",
+ "tqdm",
+ "numpy",
+ "tensorboard",
+ "torch>=1.4.0",
+ "numba>=0.51.0",
],
extras_require={
- 'dev': [
- 'Sphinx',
- 'sphinx_rtd_theme',
- 'sphinxcontrib-bibtex',
- 'flake8',
- 'pytest',
- 'pytest-cov',
- 'ray>=0.8.0',
- ],
- 'atari': [
- 'atari_py',
- 'cv2',
- ],
- 'mujoco': [
- 'mujoco_py',
- ],
- 'pybullet': [
- 'pybullet',
+ "dev": [
+ "Sphinx",
+ "sphinx_rtd_theme",
+ "sphinxcontrib-bibtex",
+ "flake8",
+ "pytest",
+ "pytest-cov",
+ "ray>=0.8.0",
+ "mypy",
+ "pydocstyle",
+ "doc8",
],
+ "atari": ["atari_py", "cv2"],
+ "mujoco": ["mujoco_py"],
+ "pybullet": ["pybullet"],
},
)
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 2175316..e7e1175 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -200,7 +200,6 @@ def test_collector_with_dict_state():
assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
c1 = Collector(policy, envs, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
- c1.seed(0)
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
batch, _ = c1.buffer.sample(10)
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index fe70d1f..7f7f10b 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -20,7 +20,7 @@ def _is_batch_set(data: Any) -> bool:
# or 1-D np.ndarray with np.object type,
# where each element is a dict/Batch object
if isinstance(data, np.ndarray): # most often case
- # ``for e in data`` will just unpack the first dimension,
+ # "for e in data" will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
return data.dtype == np.object and \
@@ -79,7 +79,8 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
def _create_value(inst: Any, size: int, stack=True) -> Union[
'Batch', np.ndarray, torch.Tensor]:
- """
+ """Create empty place-holders accroding to inst's shape.
+
:param bool stack: whether to stack or to concatenate. E.g. if inst has
shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape
of (10, 3, 5), otherwise (10, 5)
@@ -154,9 +155,13 @@ def _parse_value(v: Any):
class Batch:
- """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
- structure to pass any kind of data to other methods, for example, a
- collector gives a :class:`~tianshou.data.Batch` to policy for learning.
+ """The internal data structure in Tianshou.
+
+ Batch is a kind of supercharged array (of temporal data) stored
+ individually in a (recursive) dictionary of objects that can be either
+ numpy arrays, torch tensors, or batches themselves. It is designed to
+ make it extremely easy to access, manipulate and set partial views of
+ the heterogeneous data conveniently.
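+
+    A quick example (a minimal illustration; see :ref:`batch_concept` for the full semantics)::
+
+        >>> import numpy as np
+        >>> from tianshou.data import Batch
+        >>> data = Batch(a=4, b=[5, 5], c='2312312')
+        >>> data.b
+        array([5, 5])
+        >>> data.b = np.array([3, 4, 5])
+        >>> len(data.b)
+        3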
For a detailed description, please refer to :ref:`batch_concept`.
"""
@@ -180,12 +185,13 @@ class Batch:
self.__init__(kwargs, copy=copy)
def __setattr__(self, key: str, value: Any) -> None:
- """self.key = value"""
+ """Set self.key = value."""
self.__dict__[key] = _parse_value(value)
def __getstate__(self) -> dict:
- """Pickling interface. Only the actual data are serialized for both
- efficiency and simplicity.
+ """Pickling interface.
+
+ Only the actual data are serialized for both efficiency and simplicity.
"""
state = {}
for k, v in self.items():
@@ -195,9 +201,10 @@ class Batch:
return state
def __setstate__(self, state) -> None:
- """Unpickling interface. At this point, self is an empty Batch instance
- that has not been initialized, so it can safely be initialized by the
- pickle state.
+ """Unpickling interface.
+
+ At this point, self is an empty Batch instance that has not been
+ initialized, so it can safely be initialized by the pickle state.
"""
self.__init__(**state)
@@ -246,8 +253,7 @@ class Batch:
self.__dict__[key][index] = None
def __iadd__(self, other: Union['Batch', Number, np.number]):
- """Algebraic addition with another :class:`~tianshou.data.Batch`
- instance in-place."""
+ """Algebraic addition with another Batch instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(),
other.__dict__.values()):
@@ -268,14 +274,12 @@ class Batch:
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, other: Union['Batch', Number, np.number]):
- """Algebraic addition with another :class:`~tianshou.data.Batch`
- instance out-of-place."""
+ """Algebraic addition with another Batch instance out-of-place."""
return deepcopy(self).__iadd__(other)
def __imul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value in-place."""
- assert _is_number(val), \
- "Only multiplication by a number is supported."
+ assert _is_number(val), "Only multiplication by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
continue
@@ -288,8 +292,7 @@ class Batch:
def __itruediv__(self, val: Union[Number, np.number]):
"""Algebraic division with a scalar value in-place."""
- assert _is_number(val), \
- "Only division by a number is supported."
+ assert _is_number(val), "Only division by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
continue
@@ -336,15 +339,11 @@ class Batch:
return self.__dict__.get(k, d)
def pop(self, k: str, d: Optional[Any] = None) -> Any:
- """Return and remove self[k] if k in self else d. d defaults to
- None.
- """
+ """Return & remove self[k] if k in self else d. d defaults to None."""
return self.__dict__.pop(k, d)
def to_numpy(self) -> None:
- """Change all torch.Tensor to numpy.ndarray. This is an in-place
- operation.
- """
+ """Change all torch.Tensor to numpy.ndarray in-place."""
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.detach().cpu().numpy()
@@ -353,9 +352,7 @@ class Batch:
def to_torch(self, dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu') -> None:
- """Change all numpy.ndarray to torch.Tensor. This is an in-place
- operation.
- """
+ """Change all numpy.ndarray to torch.Tensor in-place."""
if not isinstance(device, torch.device):
device = torch.device(device)
@@ -382,7 +379,9 @@ class Batch:
def __cat(self,
batches: List[Union[dict, 'Batch']],
lens: List[int]) -> None:
- """::
+ """Private method for Batch.cat_.
+
+ ::
>>> a = Batch(a=np.random.randn(3, 4))
>>> x = Batch(a=a, b=np.random.randn(4, 4))
@@ -448,9 +447,7 @@ class Batch:
def cat_(self,
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
- """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
- into current batch.
- """
+ """Concatenate a list of (or one) Batch objects into current batch."""
if isinstance(batches, Batch):
batches = [batches]
if len(batches) == 0:
@@ -477,10 +474,10 @@ class Batch:
@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
- """Concatenate a list of :class:`~tianshou.data.Batch` object into a
- single new batch. For keys that are not shared across all batches,
- batches that do not have these keys will be padded by zeros with
- appropriate shapes. E.g.
+ """Concatenate a list of Batch object into a single new batch.
+
+ For keys that are not shared across all batches, batches that do not
+ have these keys will be padded by zeros with appropriate shapes. E.g.
::
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
@@ -500,9 +497,7 @@ class Batch:
def stack_(self,
batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
- """Stack a list of :class:`~tianshou.data.Batch` object into current
- batch.
- """
+ """Stack a list of Batch object into current batch."""
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
@@ -553,9 +548,10 @@ class Batch:
@staticmethod
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
- """Stack a list of :class:`~tianshou.data.Batch` object into a single
- new batch. For keys that are not shared across all batches,
- batches that do not have these keys will be padded by zeros. E.g.
+ """Stack a list of Batch object into a single new batch.
+
+ For keys that are not shared across all batches, batches that do not
+ have these keys will be padded by zeros. E.g.
::
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
@@ -580,9 +576,9 @@ class Batch:
def empty_(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
- """Return an empty a :class:`~tianshou.data.Batch` object with 0 or
- ``None`` filled. If ``index`` is specified, it will only reset the
- specific indexed-data.
+ """Return an empty Batch object with 0 or None filled.
+
+ If "index" is specified, it will only reset the specific indexed-data.
::
>>> data.empty_()
@@ -629,9 +625,9 @@ class Batch:
def empty(batch: 'Batch', index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
- """Return an empty :class:`~tianshou.data.Batch` object with 0 or
- ``None`` filled, the shape is the same as the given
- :class:`~tianshou.data.Batch`.
+ """Return an empty Batch object with 0 or None filled.
+
+ The shape is the same as the given Batch.
"""
return deepcopy(batch).empty_(index)
@@ -664,9 +660,10 @@ class Batch:
return min(r)
def is_empty(self, recurse: bool = False):
- """
- Test if a Batch is empty. If ``recurse=True``, it further tests the
- values of the object; else it only tests the existence of any key.
+ """Test if a Batch is empty.
+
+ If ``recurse=True``, it further tests the values of the object; else
+ it only tests the existence of any key.
``b.is_empty(recurse=True)`` is mainly used to distinguish
``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
@@ -715,11 +712,11 @@ class Batch:
"""Split whole data into multiple small batches.
:param int size: divide the data batch with the given size, but one
- batch if the length of the batch is smaller than ``size``.
+ batch if the length of the batch is smaller than "size".
:param bool shuffle: randomly shuffle the entire data batch if it is
- ``True``, otherwise remain in the same. Default to ``True``.
+ True, otherwise remain in the same order. Default to True.
:param bool merge_last: merge the last batch into the previous one.
- Default to ``False``.
+ Default to False.
"""
length = len(self)
assert 1 <= size # size can be greater than length, return whole batch
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 3c5658f..e9f47b9 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -7,9 +7,11 @@ from tianshou.data.batch import _create_value
class ReplayBuffer:
- """:class:`~tianshou.data.ReplayBuffer` stores data generated from
- interaction between the policy and environment. The current implementation
- of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`:
+ """:class:`~tianshou.data.ReplayBuffer` stores data generated from \
+ interaction between the policy and environment.
+
+ The current implementation of Tianshou typically uses 7 reserved keys in
+ :class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
@@ -113,13 +115,12 @@ class ReplayBuffer:
:param int size: the size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than or equal to 1, defaults to 1 (no stacking).
- :param bool ignore_obs_next: whether to store obs_next, defaults to
- ``False``.
+ :param bool ignore_obs_next: whether to store obs_next, defaults to False.
:param bool save_only_last_obs: only save the last obs/obs_next when it has
a shape of (timestep, ...) because of temporal stacking, defaults to
- ``False``.
+ False.
:param bool sample_avail: the parameter indicating sampling only available
- index when using frame-stack sampling method, defaults to ``False``.
+ index when using frame-stack sampling method, defaults to False.
This feature is not supported in Prioritized Replay Buffer currently.
"""
@@ -150,15 +151,17 @@ class ReplayBuffer:
return self.__class__.__name__ + self._meta.__repr__()[5:]
def __getattr__(self, key: str) -> Any:
- """Return self.key"""
+ """Return self.key."""
try:
return self._meta[key]
except KeyError as e:
raise AttributeError from e
def __setstate__(self, state):
- """Unpickling interface. We need it because pickling buffer does not
- work out-of-the-box (``buffer.__getattr__`` is customized).
+ """Unpickling interface.
+
+ We need it because pickling buffer does not work out-of-the-box
+ ("buffer.__getattr__" is customized).
"""
self.__dict__.update(state)
@@ -260,7 +263,7 @@ class ReplayBuffer:
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. \
- Return all the data in the buffer if batch_size is ``0``.
+ Return all the data in the buffer if batch_size is 0.
:return: Sample data and its corresponding index inside the buffer.
"""
@@ -280,9 +283,11 @@ class ReplayBuffer:
def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
- """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
- where s is self.key, t is indice. The stack_num (here equals to 4) is
- given from buffer initialization procedure.
+ """Return the stacked result.
+
+ E.g., [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key and t is the
+ given indice. The stack_num (4 in this example) is given during the
+ buffer initialization procedure.
"""
if stack_num is None:
stack_num = self.stack_num
@@ -325,8 +330,10 @@ class ReplayBuffer:
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
- """Return a data batch: self[index]. If stack_num is larger than 1,
- return the stacked obs and obs_next with shape [batch, len, ...].
+ """Return a data batch: self[index].
+
+ If stack_num is larger than 1, return the stacked obs and obs_next
+ with shape (batch, len, ...).
"""
return Batch(
obs=self.get(index, 'obs'),
@@ -340,9 +347,11 @@ class ReplayBuffer:
class ListReplayBuffer(ReplayBuffer):
- """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
- same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
- :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore,
+ """List-based replay buffer.
+
+ The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same
+ as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
+ :class:`~tianshou.data.ListReplayBuffer` is based on list. Therefore,
it does not support advanced indexing, which means you cannot sample a
batch of data out of it. It is typically used for storing data.
@@ -373,7 +382,7 @@ class ListReplayBuffer(ReplayBuffer):
class PrioritizedReplayBuffer(ReplayBuffer):
- """Implementation of Prioritized Experience Replay. arXiv:1511.05952
+ """Implementation of Prioritized Experience Replay. arXiv:1511.05952.
:param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient.
@@ -388,18 +397,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
super().__init__(size, **kwargs)
assert alpha > 0. and beta >= 0.
self._alpha, self._beta = alpha, beta
- self._max_prio = 1.
- self._min_prio = 1.
- # bypass the check
- self._weight = SegmentTree(size)
+ self._max_prio = self._min_prio = 1.0
+ # save weight directly in this class instead of self._meta
+ self.weight = SegmentTree(size)
self.__eps = np.finfo(np.float32).eps.item()
- def __getattr__(self, key: str) -> Union['Batch', Any]:
- """Return self.key"""
- if key == 'weight':
- return self._weight
- return super().__getattr__(key)
-
def add(self,
obs: Union[dict, Batch, np.ndarray, float],
act: Union[dict, Batch, np.ndarray, float],
@@ -418,15 +420,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_prio = max(self._max_prio, weight)
self._min_prio = min(self._min_prio, weight)
self.weight[self._index] = weight ** self._alpha
- super().add(obs, act, rew, done, obs_next, info, policy)
+ super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
- """Get a random sample from buffer with priority probability. Return
- all the data in the buffer if batch_size is ``0``.
+ """Get a random sample from buffer with priority probability.
+
+ Return all the data in the buffer if batch_size is 0.
:return: Sample data and its corresponding index inside the buffer.
- The ``weight`` in the returned Batch is the weight on loss function
+ The "weight" in the returned Batch is the weight on loss function
to de-bias the sampling process (some transition tuples are sampled
more often so their losses are weighted less).
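+
+        A usage sketch (illustrative; ``td_error`` is assumed to be computed
+        by the policy as a torch.Tensor, with ``batch.weight`` converted
+        accordingly)::
+
+            batch, indice = buffer.sample(batch_size)
+            loss = (td_error.pow(2) * batch.weight).mean()
+            buffer.update_weight(indice, td_error)  # refresh the priorities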
"""
@@ -440,7 +443,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
scalar = np.random.rand(batch_size) * self.weight.reduce()
indice = self.weight.get_prefix_sum_idx(scalar)
batch = self[indice]
- # impt_weight
+ # importance sampling weight calculation
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
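To make the simplified weight formula concrete, here is a small numpy sketch (toy priorities, not the buffer's real code) showing how frequently sampled transitions end up with importance weights below 1::

    import numpy as np

    alpha, beta = 0.6, 0.4
    prio = np.array([1.0, 2.0, 8.0])    # raw priorities of three transitions
    p = prio ** alpha                   # what the segment tree actually stores
    probs = p / p.sum()                 # sampling probability of each index
    weight = (p / p.min()) ** (-beta)   # (p_j / p_min) ** (-beta)
    print(probs, weight)                # the highest-priority entry gets the smallest weight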
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 4ffacac..65f869a 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -14,8 +14,7 @@ from tianshou.data.batch import _create_value
class Collector(object):
- """The :class:`~tianshou.data.Collector` enables the policy to interact
- with different types of environments conveniently.
+ """Collector enables the policy to interact with different types of envs.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
@@ -25,7 +24,7 @@ class Collector(object):
class. If set to ``None`` (testing phase), it will not store the data.
:param function preprocess_fn: a function called before the data has been
added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
- to ``None``.
+ to None.
:param BaseNoise action_noise: add a noise to continuous action. Normally
a policy already has a noise param for exploration in training phase,
so this is recommended to use in test collector for some purpose.
@@ -42,7 +41,7 @@ class Collector(object):
:class:`~tianshou.data.Batch` with the modified keys and values. Examples
are in "test/base/test_collector.py".
- Example:
+ Here is an example:
::
policy = PGPolicy(...) # or other policies if you wish
@@ -139,9 +138,7 @@ class Collector(object):
return self.env_num
def reset_env(self) -> None:
- """Reset all of the environment(s)' states and reset all of the cache
- buffers (if need).
- """
+ """Reset all of the environment(s)' states and the cache buffers."""
self._ready_env_ids = np.arange(self.env_num)
obs = self.env.reset()
if self.preprocess_fn:
@@ -150,14 +147,6 @@ class Collector(object):
for b in self._cached_buf:
b.reset()
- def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
- """Reset all the seed(s) of the given environment(s)."""
- return self.env.seed(seed)
-
- def render(self, **kwargs) -> None:
- """Render all the environment(s)."""
- return self.env.render(**kwargs)
-
def _reset_state(self, id: Union[int, List[int]]) -> None:
"""Reset the hidden state: self.data.state[id]."""
state = self.data.state # it is a reference
@@ -183,11 +172,11 @@ class Collector(object):
a list, it means to collect exactly ``n_episode[i]`` episodes in
the i-th environment
:param bool random: whether to use random policy for collecting data,
- defaults to ``False``.
+ defaults to False.
:param float render: the sleep time between rendering consecutive
- frames, defaults to ``None`` (no rendering).
+ frames, defaults to None (no rendering).
:param bool no_grad: whether to retain gradient in policy.forward,
- defaults to ``True`` (no gradient retaining).
+ defaults to True (no gradient retaining).
.. note::
@@ -291,7 +280,7 @@ class Collector(object):
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if render:
- self.render()
+ self.env.render()
time.sleep(render)
# add data into the buffer
@@ -378,9 +367,10 @@ class Collector(object):
}
def sample(self, batch_size: int) -> Batch:
- """Sample a data batch from the internal replay buffer. It will call
- :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
- final batch data.
+ """Sample a data batch from the internal replay buffer.
+
+ It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
+ returning the final batch data.
:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
index e97b054..dd36c73 100644
--- a/tianshou/data/utils/converter.py
+++ b/tianshou/data/utils/converter.py
@@ -67,8 +67,9 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
def to_torch_as(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
y: torch.Tensor
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
- """Return an object without np.ndarray. Same as
- ``to_torch(x, dtype=y.dtype, device=y.device)``.
+ """Return an object without np.ndarray.
+
+ Same as ``to_torch(x, dtype=y.dtype, device=y.device)``.
"""
assert isinstance(y, torch.Tensor)
return to_torch(x, dtype=y.dtype, device=y.device)
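A short usage sketch (assuming ``to_torch_as`` is importable from ``tianshou.data``, as elsewhere in this diff)::

    import numpy as np
    import torch
    from tianshou.data import to_torch_as

    y = torch.zeros(3, dtype=torch.float64)
    x = np.ones((2, 3), dtype=np.float32)
    t = to_torch_as(x, y)     # same as to_torch(x, dtype=y.dtype, device=y.device)
    print(t.dtype, t.device)  # torch.float64 cpu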
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index e049f3a..c4b4871 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -4,13 +4,13 @@ from typing import Union, Optional
class SegmentTree:
- """Implementation of Segment Tree: store an array ``arr`` with size ``n``
- in a segment tree, support value update and fast query of the sum for the
- interval ``[left, right)`` in O(log n) time.
+ """Implementation of Segment Tree.
- The detailed procedure is as follows:
+ The segment tree stores an array ``arr`` with size ``n``. It supports value
+ update and fast query of the sum for the interval ``[left, right)`` in
+ O(log n) time. The detailed procedure is as follows:
- 1. Pad the array to have length of power of 2, so that leaf nodes in the\
+ 1. Pad the array to have a length that is a power of 2, so that leaf nodes in the \
segment tree have the same depth.
2. Store the segment tree in a binary heap.
@@ -30,12 +30,14 @@ class SegmentTree:
def __getitem__(self, index: Union[int, np.ndarray]
) -> Union[float, np.ndarray]:
- """Return self[index]"""
+ """Return self[index]."""
return self._value[index + self._bound]
def __setitem__(self, index: Union[int, np.ndarray],
value: Union[float, np.ndarray]) -> None:
- """Duplicate values in ``index`` are handled by numpy: later index
+ """Update values in segment tree.
+
+ Duplicate values in ``index`` are handled by numpy: later index
overwrites previous ones.
::
@@ -61,9 +63,11 @@ class SegmentTree:
def get_prefix_sum_idx(
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
- """Return the minimum index for each ``v`` in ``value`` so that
- :math:`v \\le \\mathrm{sums}_i`, where :math:`\\mathrm{sums}_i =
- \\sum_{j=0}^{i} \\mathrm{arr}_j`.
+ r"""Find the index with given value.
+
+ Return the minimum index for each ``v`` in ``value`` so that
+ :math:`v \le \mathrm{sums}_i`, where
+ :math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`.
.. warning::
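For a quick sanity check of the definition above, the same query can be written with plain numpy (ignoring the zero-weight corner case the warning refers to)::

    import numpy as np

    arr = np.array([0.5, 1.0, 0.0, 2.5])
    sums = np.cumsum(arr)                             # [0.5, 1.5, 1.5, 4.0]
    value = np.array([0.2, 0.6, 3.9])
    # minimum index i such that value <= sums_i
    print(np.searchsorted(sums, value, side='left'))  # [0 1 3]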
@@ -81,7 +85,7 @@ class SegmentTree:
@njit
def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:
- """4x faster: 0.1 -> 0.024"""
+ """Numba version, 4x faster: 0.1 -> 0.024."""
tree[index] = value
while index[0] > 1:
index //= 2
@@ -90,7 +94,7 @@ def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:
@njit
def _reduce(tree: np.ndarray, start: int, end: int) -> float:
- """2x faster: 0.009 -> 0.005"""
+ """Numba version, 2x faster: 0.009 -> 0.005."""
# nodes in (start, end) should be aggregated
result = 0.
while end - start > 1: # (start, end) interval is not empty
@@ -106,7 +110,8 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float:
@njit
def _get_prefix_sum_idx(value: np.ndarray, bound: int,
sums: np.ndarray) -> np.ndarray:
- """numba version (v0.51), 5x speed up with size=100000 and bsz=64
+ """Numba version (v0.51), 5x speed up with size=100000 and bsz=64.
+
vectorized np: 0.0923 (numpy best) -> 0.024 (now)
for-loop: 0.2914 -> 0.019 (but not so stable)
"""
diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py
index ea9284e..9153cf5 100644
--- a/tianshou/env/maenv.py
+++ b/tianshou/env/maenv.py
@@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
class MultiAgentEnv(ABC, gym.Env):
- """The interface for multi-agent environments. Multi-agent environments
- must be wrapped as :class:`~tianshou.env.MultiAgentEnv`. Here is the usage:
+ """The interface for multi-agent environments.
+
+ Multi-agent environments must be wrapped as
+ :class:`~tianshou.env.MultiAgentEnv`. Here is the usage:
::
env = MultiAgentEnv(...)
@@ -25,18 +27,20 @@ class MultiAgentEnv(ABC, gym.Env):
@abstractmethod
def reset(self) -> dict:
- """Reset the state. Return the initial state, first agent_id, and the
- initial action set, for example,
- ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``
+ """Reset the state.
+
+ Return the initial state, first agent_id, and the initial action set,
+ for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``
"""
pass
@abstractmethod
def step(self, action: np.ndarray
) -> Tuple[dict, np.ndarray, np.ndarray, np.ndarray]:
- """Run one timestep of the environment’s dynamics. When the end of
- episode is reached, you are responsible for calling reset() to reset
- the environment’s state.
+ """Run one timestep of the environment’s dynamics.
+
+ When the end of episode is reached, you are responsible for calling
+ reset() to reset the environment’s state.
Accept action and return a tuple (obs, rew, done, info).
diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py
index 41b9ede..f7d8c58 100644
--- a/tianshou/env/utils.py
+++ b/tianshou/env/utils.py
@@ -2,7 +2,7 @@ import cloudpickle
class CloudpickleWrapper(object):
- """A cloudpickle wrapper used in :class:`~tianshou.env.SubprocVectorEnv`"""
+ """A cloudpickle wrapper used in SubprocVectorEnv."""
def __init__(self, data):
self.data = data
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
index 0432349..72c5b9e 100644
--- a/tianshou/env/venvs.py
+++ b/tianshou/env/venvs.py
@@ -8,7 +8,9 @@ from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
class BaseVectorEnv(gym.Env):
- """Base class for vectorized environments wrapper. Usage:
+ """Base class for vectorized environments wrapper.
+
+ Usage:
::
env_num = 8
@@ -45,7 +47,7 @@ class BaseVectorEnv(gym.Env):
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
env.
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
- worker which contains this env.
+ worker which contains the i-th env.
:param int wait_num: use in asynchronous simulation if the time cost of
``env.step`` varies with time and synchronously waiting for all
environments to finish a step is time-wasting. In that case, we can
@@ -98,10 +100,12 @@ class BaseVectorEnv(gym.Env):
return self.env_num
def __getattribute__(self, key: str) -> Any:
- """Any class who inherits ``gym.Env`` will inherit some attributes,
- like ``action_space``. However, we would like the attribute lookup to
- go straight into the worker (in fact, this vector env's action_space
- is always ``None``).
+ """Switch the attribute getter depending on the key.
+
+ Any class that inherits ``gym.Env`` will inherit some attributes, like
+ ``action_space``. However, we would like the attribute lookup to go
+ straight into the worker (in fact, this vector env's action_space is
+ always None).
"""
if key in ['metadata', 'reward_range', 'spec', 'action_space',
'observation_space']: # reserved keys in gym.Env
@@ -110,9 +114,11 @@ class BaseVectorEnv(gym.Env):
return super().__getattribute__(key)
def __getattr__(self, key: str) -> Any:
- """Try to retrieve an attribute from each individual wrapped
- environment, if it does not belong to the wrapping vector environment
- class.
+ """Fetch a list of env attributes.
+
+ This function tries to retrieve an attribute from each individual
+ wrapped environment, if it does not belong to the wrapping vector
+ environment class.
"""
return [getattr(worker, key) for worker in self.workers]
@@ -133,9 +139,11 @@ class BaseVectorEnv(gym.Env):
def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> np.ndarray:
- """Reset the state of all the environments and return initial
- observations if id is ``None``, otherwise reset the specific
- environments with the given id, either an int or a list.
+ """Reset the state of some envs and return initial observations.
+
+ If id is None, reset the state of all the environments and return
+ initial observations, otherwise reset the specific environments with
+ the given id, either an int or a list.
"""
self._assert_is_not_closed()
id = self._wrap_id(id)
@@ -148,7 +156,9 @@ class BaseVectorEnv(gym.Env):
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None
) -> List[np.ndarray]:
- """Run one timestep of all the environments’ dynamics if id is "None",
+ """Run one timestep of some environments' dynamics.
+
+ If id is None, run one timestep of all the environments’ dynamics;
otherwise run one timestep for some environments with given id, either
an int or a list. When the end of episode is reached, you are
responsible for calling reset(id) to reset this environment’s state.
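A minimal usage sketch of the id-based reset/step described above, assuming ``DummyVectorEnv`` is importable from ``tianshou.env`` as the rest of this diff suggests::

    import gym
    import numpy as np
    from tianshou.env import DummyVectorEnv

    envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
    obs = envs.reset()               # reset all 4 envs, obs.shape == (4, 4)
    obs = envs.reset(id=[0, 2])      # reset only env 0 and env 2
    obs, rew, done, info = envs.step(np.array([0, 1]), id=[0, 2])
    envs.close()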
@@ -175,7 +185,7 @@ class BaseVectorEnv(gym.Env):
should correspond to the ``id`` argument, and the ``id`` argument
should be a subset of the ``env_id`` in the last returned ``info``
(initially they are env_ids of all the environments). If action is
- ``None``, fetch unfinished step() calls instead.
+ None, fetch unfinished step() calls instead.
"""
self._assert_is_not_closed()
id = self._wrap_id(id)
@@ -239,9 +249,11 @@ class BaseVectorEnv(gym.Env):
return [w.render(**kwargs) for w in self.workers]
def close(self) -> None:
- """Close all of the environments. This function will be called only
- once (if not, it will be called during garbage collected). This way,
- ``close`` of all workers can be assured.
+ """Close all of the environments.
+
+ This function will be called only once (if not, it will be called
+ during garbage collection). This way, ``close`` of all workers can be
+ assured.
"""
self._assert_is_not_closed()
for w in self.workers:
@@ -249,6 +261,7 @@ class BaseVectorEnv(gym.Env):
self.is_closed = True
def __del__(self) -> None:
+ """Redirect to self.close()."""
if not self.is_closed:
self.close()
@@ -270,6 +283,8 @@ class DummyVectorEnv(BaseVectorEnv):
class VectorEnv(DummyVectorEnv):
+ """VectorEnv is renamed to DummyVectorEnv."""
+
def __init__(self, *args, **kwargs) -> None:
warnings.warn(
'VectorEnv is renamed to DummyVectorEnv, and will be removed in '
@@ -296,9 +311,9 @@ class SubprocVectorEnv(BaseVectorEnv):
class ShmemVectorEnv(BaseVectorEnv):
- """Optimized version of SubprocVectorEnv which uses shared variables to
- communicate observations. ShmemVectorEnv has exactly the same API as
- SubprocVectorEnv.
+ """Optimized SubprocVectorEnv with shared buffers to exchange observations.
+
+ ShmemVectorEnv has exactly the same API as SubprocVectorEnv.
.. seealso::
@@ -316,9 +331,9 @@ class ShmemVectorEnv(BaseVectorEnv):
class RayVectorEnv(BaseVectorEnv):
- """Vectorized environment wrapper based on
- `ray `_. This is a choice to run
- distributed environments in a cluster.
+ """Vectorized environment wrapper based on ray.
+
+ It provides a way to run distributed environments in a cluster.
.. seealso::
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index 2b56dab..c3600fa 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -30,10 +30,12 @@ class EnvWorker(ABC):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
- """``send_action`` and ``get_result`` are coupled in sync simulation,
- so typically users only call ``step`` function. But they can be called
- separately in async simulation, i.e. someone calls ``send_action``
- first, and calls ``get_result`` later.
+ """Perform one timestep of the environment's dynamic.
+
+ "send_action" and "get_result" are coupled in sync simulation, so
+ typically users only call "step" function. But they can be called
+ separately in async simulation, i.e. someone calls "send_action" first,
+ and calls "get_result" later.
"""
self.send_action(action)
return self.get_result()
diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py
index 893500b..b705913 100644
--- a/tianshou/env/worker/dummy.py
+++ b/tianshou/env/worker/dummy.py
@@ -22,7 +22,7 @@ class DummyEnvWorker(EnvWorker):
def wait(workers: List['DummyEnvWorker'],
wait_num: int,
timeout: Optional[float] = None) -> List['DummyEnvWorker']:
- # SequentialEnvWorker objects are always ready
+ # Sequential EnvWorker objects are always ready
return workers
def send_action(self, action: np.ndarray) -> None:
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 3186b01..857d148 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -76,7 +76,7 @@ _NP_TO_CT = {
class ShArray:
- """Wrapper of multiprocessing Array"""
+ """Wrapper of multiprocessing Array."""
def __init__(self, dtype, shape):
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
index 19f4424..d9ef006 100644
--- a/tianshou/exploration/random.py
+++ b/tianshou/exploration/random.py
@@ -20,9 +20,7 @@ class BaseNoise(ABC, object):
class GaussianNoise(BaseNoise):
- """Class for vanilla gaussian process,
- used for exploration in DDPG by default.
- """
+ """The vanilla gaussian process, for exploration in DDPG by default."""
def __init__(self,
mu: float = 0.0,
@@ -38,6 +36,7 @@ class GaussianNoise(BaseNoise):
class OUNoise(BaseNoise):
"""Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
+
Usage:
::
@@ -67,8 +66,9 @@ class OUNoise(BaseNoise):
self.reset()
def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray:
- """Generate new noise. Return a ``numpy.ndarray`` which size is equal
- to ``size``.
+ """Generate new noise.
+
+ Return a ``numpy.ndarray`` whose size is equal to ``size``.
"""
if self._x is None or self._x.shape != size:
self._x = 0
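A usage sketch, assuming ``OUNoise``/``GaussianNoise`` are importable from ``tianshou.exploration`` (as the policy modules below do) and accept a ``sigma`` keyword::

    from tianshou.exploration import GaussianNoise, OUNoise

    ou = OUNoise(sigma=0.3)
    eps = ou((8, 2))       # np.ndarray of shape (8, 2), temporally correlated
    ou.reset()             # forget the internal state, e.g. between episodes
    print(eps.shape, GaussianNoise(sigma=0.1)((8, 2)).shape)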
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 7b943ae..a380670 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -11,8 +11,10 @@ from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
class BasePolicy(ABC, nn.Module):
- """Tianshou aims to modularizing RL algorithms. It comes into several
- classes of policies in Tianshou. All of the policy classes must inherit
+ """The base class for any RL policy.
+
+ Tianshou aims to modularize RL algorithms. It comes with several classes
+ of policies. All of the policy classes must inherit
:class:`~tianshou.policy.BasePolicy`.
A policy class typically has four parts:
@@ -29,27 +31,25 @@ class BasePolicy(ABC, nn.Module):
Most of the policy needs a neural network to predict the action and an
optimizer to optimize the policy. The rules of self-defined networks are:
- 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, a \
- ``torch.Tensor``, a dict or any others), hidden state ``state`` (for \
- RNN usage), and other information ``info`` provided by the \
- environment.
- 2. Output: some ``logits``, the next hidden state ``state``, and the \
- intermediate result during policy forwarding procedure ``policy``. The\
- ``logits`` could be a tuple instead of a ``torch.Tensor``. It depends \
- on how the policy process the network output. For example, in PPO, the\
- return of the network might be ``(mu, sigma), state`` for Gaussian \
- policy. The ``policy`` can be a Batch of torch.Tensor or other things,\
- which will be stored in the replay buffer, and can be accessed in the \
- policy update process (e.g. in ``policy.learn()``, the \
- ``batch.policy`` is what you need).
+ 1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
+ ``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \
+ usage), and other information "info" provided by the environment.
+ 2. Output: some "logits", the next hidden state "state", and the \
+ intermediate result during policy forwarding procedure "policy". The \
+ "logits" could be a tuple instead of a ``torch.Tensor``. It depends on how\
+ the policy processes the network output. For example, in PPO, the return of \
+ the network might be ``(mu, sigma), state`` for Gaussian policy. The \
+ "policy" can be a Batch of torch.Tensor or other things, which will be \
+ stored in the replay buffer, and can be accessed in the policy update \
+ process (e.g. in "policy.learn()", the "batch.policy" is what you need).
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
you can use :class:`~tianshou.policy.BasePolicy` almost the same as
``torch.nn.Module``, for instance, loading and saving the model:
::
- torch.save(policy.state_dict(), 'policy.pth')
- policy.load_state_dict(torch.load('policy.pth'))
+ torch.save(policy.state_dict(), "policy.pth")
+ policy.load_state_dict(torch.load("policy.pth"))
"""
def __init__(self,
@@ -62,7 +62,7 @@ class BasePolicy(ABC, nn.Module):
self.agent_id = 0
def set_agent_id(self, agent_id: int) -> None:
- """set self.agent_id = agent_id, for MARL."""
+ """Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id
@abstractmethod
@@ -86,7 +86,7 @@ class BasePolicy(ABC, nn.Module):
return Batch(logits=..., act=..., state=None, dist=...)
The keyword ``policy`` is reserved and the corresponding data will be
- stored into the replay buffer in numpy. For instance,
+ stored into the replay buffer. For instance,
::
# some code
@@ -98,8 +98,10 @@ class BasePolicy(ABC, nn.Module):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
- """Pre-process the data from the provided replay buffer. Check out
- :ref:`policy_concept` for more information.
+ """Pre-process the data from the provided replay buffer.
+
+ Used in :meth:`update`. Check out :ref:`process_fn` for more
+ information.
"""
return batch
@@ -123,26 +125,28 @@ class BasePolicy(ABC, nn.Module):
def post_process_fn(self, batch: Batch,
buffer: ReplayBuffer, indice: np.ndarray) -> None:
- """Post-process the data from the provided replay buffer. Typical
- usage is to update the sampling weight in prioritized experience
- replay. Check out :ref:`policy_concept` for more information.
+ """Post-process the data from the provided replay buffer.
+
+ Typical usage is to update the sampling weight in prioritized
+ experience replay. Used in :meth:`update`.
"""
if isinstance(buffer, PrioritizedReplayBuffer) \
and hasattr(batch, 'weight'):
buffer.update_weight(indice, batch.weight)
- def update(self, batch_size: int, buffer: Optional[ReplayBuffer],
+ def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
*args, **kwargs) -> Dict[str, Union[float, List[float]]]:
- """Update the policy network and replay buffer (if needed). It includes
- three function steps: process_fn, learn, and post_process_fn.
+ """Update the policy network and replay buffer.
- :param int batch_size: 0 means it will extract all the data from the
- buffer, otherwise it will sample a batch with the given batch_size.
+ It includes three steps: process_fn, learn, and post_process_fn.
+
+ :param int sample_size: 0 means it will extract all the data from the
+ buffer, otherwise it will sample a batch with the given sample_size.
:param ReplayBuffer buffer: the corresponding replay buffer.
"""
if buffer is None:
return {}
- batch, indice = buffer.sample(batch_size)
+ batch, indice = buffer.sample(sample_size)
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, *args, **kwargs)
self.post_process_fn(batch, buffer, indice)
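The three-step flow above can be summarised as the following sketch, which mirrors the code in this hunk::

    def update(policy, sample_size, buffer, **kwargs):
        if buffer is None:
            return {}
        batch, indice = buffer.sample(sample_size)        # 0 -> the whole buffer
        batch = policy.process_fn(batch, buffer, indice)  # e.g. compute returns
        result = policy.learn(batch, **kwargs)            # gradient step(s)
        policy.post_process_fn(batch, buffer, indice)     # e.g. update PER weights
        return result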
@@ -156,8 +160,9 @@ class BasePolicy(ABC, nn.Module):
gae_lambda: float = 0.95,
rew_norm: bool = False,
) -> Batch:
- """Compute returns over given full-length episodes, including the
- implementation of Generalized Advantage Estimator (arXiv:1506.02438).
+ """Compute returns over given full-length episodes.
+
+ Implementation of Generalized Advantage Estimator (arXiv:1506.02438).
:param batch: a data batch which contains several full-episode data
chronologically.
@@ -169,7 +174,7 @@ class BasePolicy(ABC, nn.Module):
:param float gae_lambda: the parameter for Generalized Advantage
Estimation, should be in [0, 1], defaults to 0.95.
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
- to ``False``.
+ to False.
:return: a Batch. The result will be stored in batch.returns as a numpy
array with shape (bsz, ).
@@ -192,13 +197,13 @@ class BasePolicy(ABC, nn.Module):
n_step: int = 1,
rew_norm: bool = False,
) -> Batch:
- r"""Compute n-step return for Q-learning targets:
+ r"""Compute n-step return for Q-learning targets.
.. math::
G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
\gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})
- , where :math:`\gamma` is the discount factor,
+ where :math:`\gamma` is the discount factor,
:math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
:math:`t`.
@@ -216,7 +221,7 @@ class BasePolicy(ABC, nn.Module):
:param int n_step: the number of estimation step, should be an int
greater than 0, defaults to 1.
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
- to ``False``.
+ to False.
:return: a Batch. The result will be stored in batch.returns as a
torch.Tensor with shape (bsz, ).
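A toy numeric check of the formula above with gamma = 0.9, n = 2 and no terminal states (so every d_i = 0)::

    import numpy as np

    gamma, n = 0.9, 2
    rew = np.array([1.0, 2.0])     # r_t and r_{t+1}
    q_target = 5.0                 # Q_target(s_{t+n})
    g_t = rew[0] + gamma * rew[1] + gamma ** n * q_target
    print(g_t)                     # 1 + 0.9 * 2 + 0.81 * 5 = 6.85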
@@ -249,7 +254,7 @@ def _episodic_return(
v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray,
gamma: float, gae_lambda: float,
) -> np.ndarray:
- """Numba speedup: 4.1s -> 0.057s"""
+ """Numba speedup: 4.1s -> 0.057s."""
returns = np.roll(v_s_, 1)
m = (1. - done) * gamma
delta = rew + v_s_ * m - returns
@@ -267,7 +272,7 @@ def _nstep_return(
indice: np.ndarray, gamma: float, n_step: int, buf_len: int,
mean: float, std: float
) -> np.ndarray:
- """Numba speedup: 0.3s -> 0.15s"""
+ """Numba speedup: 0.3s -> 0.15s."""
returns = np.zeros(indice.shape)
gammas = np.full(indice.shape, n_step)
for n in range(n_step - 1, -1, -1):
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 0f7cffd..dfdc02d 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -9,24 +9,23 @@ from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class A2CPolicy(PGPolicy):
- """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783
+ """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
- :param torch.distributions.Distribution dist_fn: for computing the action,
- defaults to ``torch.distributions.Categorical``.
+ :param dist_fn: distribution class for computing the action.
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param float max_grad_norm: clipping gradients in back propagation,
- defaults to ``None``.
+ defaults to None.
:param float gae_lambda: in [0, 1], param for Generalized Advantage
Estimation, defaults to 0.95.
:param bool reward_normalization: normalize the reward to Normal(0, 1),
- defaults to ``False``.
+ defaults to False.
:param int max_batchsize: the maximum size of the batch when computing GAE,
depends on the size of available memory and the memory cost of the
model; should be as large as possible within the memory constraint;
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 6c34e34..f8cf60a 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -9,7 +9,7 @@ from tianshou.data import Batch, ReplayBuffer, to_torch_as
class DDPGPolicy(BasePolicy):
- """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971
+ """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -25,9 +25,9 @@ class DDPGPolicy(BasePolicy):
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
- defaults to ``False``.
+ defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
- defaults to ``False``.
+ defaults to False.
:param int estimation_step: greater than 1, the number of steps to look
ahead.
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 5c5e45d..4018750 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -8,12 +8,12 @@ from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class DQNPolicy(BasePolicy):
- """Implementation of Deep Q Network. arXiv:1312.5602
+ """Implementation of Deep Q Network. arXiv:1312.5602.
- Implementation of Double Q-Learning. arXiv:1509.06461
+ Implementation of Double Q-Learning. arXiv:1509.06461.
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
- implemented in the network side, not here)
+ implemented in the network side, not here).
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -21,10 +21,10 @@ class DQNPolicy(BasePolicy):
:param float discount_factor: in [0, 1].
:param int estimation_step: greater than 1, the number of steps to look
ahead.
- :param int target_update_freq: the target network update frequency (``0``
- if you do not use the target network).
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1),
- defaults to ``False``.
+ defaults to False.
.. seealso::
@@ -87,8 +87,10 @@ class DQNPolicy(BasePolicy):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
- """Compute the n-step return for Q-learning targets. More details can
- be found at :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
+ """Compute the n-step return for Q-learning targets.
+
+ More details can be found at
+ :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
"""
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q,
@@ -101,9 +103,10 @@ class DQNPolicy(BasePolicy):
input: str = 'obs',
eps: Optional[float] = None,
**kwargs) -> Batch:
- """Compute action over the given batch data. If you need to mask the
- action, please add a "mask" into batch.obs, for example, if we have an
- environment that has "0/1/2" three actions:
+ """Compute action over the given batch data.
+
+ If you need to mask the action, please add a "mask" into batch.obs, for
+ example, if we have an environment that has "0/1/2" three actions:
::
batch == Batch(
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 3eaae64..e0a5212 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -12,7 +12,7 @@ class PGPolicy(BasePolicy):
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
- :param torch.distributions.Distribution dist_fn: for computing the action.
+ :param dist_fn: distribution class for computing the action.
:param float discount_factor: in [0, 1].
.. seealso::
@@ -38,12 +38,12 @@ class PGPolicy(BasePolicy):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
- r"""Compute the discounted returns for each frame:
+ r"""Compute the discounted returns for each frame.
.. math::
G_t = \sum_{i=t}^T \gamma^{i-t}r_i
- , where :math:`T` is the terminal time step, :math:`\gamma` is the
+ where :math:`T` is the terminal time step, :math:`\gamma` is the
discount factor, :math:`\gamma \in [0, 1]`.
"""
# batch.returns = self._vanilla_returns(batch)
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 2db5baf..df84eb9 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -8,17 +8,17 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
class PPOPolicy(PGPolicy):
- r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347
+ r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
- :param torch.distributions.Distribution dist_fn: for computing the action.
+ :param dist_fn: distribution class for computing the action.
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float max_grad_norm: clipping gradients in back propagation,
- defaults to ``None``.
+ defaults to None.
:param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
paper, defaults to 0.2.
:param float vf_coef: weight for value loss, defaults to 0.5.
@@ -31,9 +31,9 @@ class PPOPolicy(PGPolicy):
where c > 1 is a constant indicating the lower bound,
defaults to 5.0 (set ``None`` if you do not want to use it).
:param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1,
- defaults to ``True``.
+ defaults to True.
:param bool reward_normalization: normalize the returns to Normal(0, 1),
- defaults to ``True``.
+ defaults to True.
:param int max_batchsize: the maximum size of the batch when computing GAE,
depends on the size of available memory and the memory cost of the
model; should be as large as possible within the memory constraint;
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index dfbc60e..920bfb1 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -10,7 +10,7 @@ from tianshou.exploration import BaseNoise
class SACPolicy(DDPGPolicy):
- """Implementation of Soft Actor-Critic. arXiv:1812.05905
+ """Implementation of Soft Actor-Critic. arXiv:1812.05905.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -35,9 +35,9 @@ class SACPolicy(DDPGPolicy):
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
- defaults to ``False``.
+ defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
- defaults to ``False``.
+ defaults to False.
:param BaseNoise exploration_noise: add a noise to action for exploration.
This is useful when solving hard-exploration problem.
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index 9150f37..384d4b9 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -9,8 +9,7 @@ from tianshou.exploration import BaseNoise, GaussianNoise
class TD3Policy(DDPGPolicy):
- """Implementation of Twin Delayed Deep Deterministic Policy Gradient,
- arXiv:1802.09477
+ """Implementation of TD3, arXiv:1802.09477.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -37,9 +36,9 @@ class TD3Policy(DDPGPolicy):
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
- defaults to ``False``.
+ defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
- defaults to ``False``.
+ defaults to False.
.. seealso::
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
index 74ab0ec..541481e 100644
--- a/tianshou/policy/multiagent/mapolicy.py
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -6,7 +6,9 @@ from tianshou.data import Batch, ReplayBuffer
class MultiAgentPolicyManager(BasePolicy):
- """This multi-agent policy manager accepts a list of
+ """Multi-agent policy manager for MARL.
+
+ This multi-agent policy manager accepts a list of
:class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to each
of these policies when the "forward" is called. The same as "process_fn"
and "learn": it splits the data and feeds them to each policy. A figure in
@@ -28,8 +30,10 @@ class MultiAgentPolicyManager(BasePolicy):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
- """Save original multi-dimensional rew in "save_rew", set rew to the
- reward of each agent during their ``process_fn``, and restore the
+ """Dispatch batch data from obs.agent_id to every policy's process_fn.
+
+ Save original multi-dimensional rew in "save_rew", set rew to the
+ reward of each agent during their "process_fn", and restore the
original reward afterwards.
"""
results = {}
@@ -57,7 +61,9 @@ class MultiAgentPolicyManager(BasePolicy):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch]] = None,
**kwargs) -> Batch:
- """:param state: if None, it means all agents have no state. If not
+ """Dispatch batch data from obs.agent_id to every policy's forward.
+
+ :param state: if None, it means all agents have no state. If not
None, it should contain keys of "agent_1", "agent_2", ...
:return: a Batch with the following contents:
@@ -120,7 +126,9 @@ class MultiAgentPolicyManager(BasePolicy):
def learn(self, batch: Batch, **kwargs
) -> Dict[str, Union[float, List[float]]]:
- """:return: a dict with the following contents:
+ """Dispatch the data to all policies for learning.
+
+ :return: a dict with the following contents:
::
diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py
index a300e8c..baac742 100644
--- a/tianshou/policy/random.py
+++ b/tianshou/policy/random.py
@@ -6,19 +6,20 @@ from tianshou.policy import BasePolicy
class RandomPolicy(BasePolicy):
- """A random agent used in multi-agent learning. It randomly chooses an
- action from the legal action.
+ """A random agent used in multi-agent learning.
+
+ It randomly chooses an action from the legal actions.
"""
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
- """Compute the random action over the given batch data. The input
- should contain a mask in batch.obs, with "True" to be available and
- "False" to be unavailable.
- For example, ``batch.obs.mask == np.array([[False, True, False]])``
- means with batch size 1, action "1" is available but action "0" and
- "2" are unavailable.
+ """Compute the random action over the given batch data.
+
+ The input should contain a mask in batch.obs, where "True" marks an
+ available action and "False" an unavailable one. For example,
+ ``batch.obs.mask == np.array([[False, True, False]])`` means that with
+ batch size 1, action "1" is available but actions "0" and "2" are not.
:return: A :class:`~tianshou.data.Batch` with "act" key, containing
the random action.
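A plain-numpy sketch of the same idea (sampling only among masked-in actions); this is only an illustration, not necessarily the exact code of ``RandomPolicy``::

    import numpy as np

    mask = np.array([[False, True, False],    # only action 1 is legal
                     [True,  True, False]])   # actions 0 and 1 are legal
    logits = np.random.rand(*mask.shape) * mask   # illegal actions get score 0
    act = logits.argmax(axis=-1)
    print(act)   # always a legal action in each row, e.g. [1 0]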
@@ -35,6 +36,5 @@ class RandomPolicy(BasePolicy):
def learn(self, batch: Batch, **kwargs
) -> Dict[str, Union[float, List[float]]]:
- """No need of a learn function for a random agent, so it returns an
- empty dict."""
+ """Since a random agent learn nothing, it returns an empty dict."""
return {}
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index a7c6ffa..c04a4b5 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -28,8 +28,9 @@ def offpolicy_trainer(
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
- """A wrapper for off-policy trainer procedure. The ``step`` in trainer
- means a policy network update.
+ """A wrapper for off-policy trainer procedure.
+
+ The "step" in trainer means a policy network update.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 0a564fc..6af4317 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -28,8 +28,9 @@ def onpolicy_trainer(
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
- """A wrapper for on-policy trainer procedure. The ``step`` in trainer means
- a policy network update.
+ """A wrapper for on-policy trainer procedure.
+
+ The "step" in trainer means a policy network update.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
diff --git a/tianshou/utils/compile.py b/tianshou/utils/compile.py
index b3700b3..bf051bd 100644
--- a/tianshou/utils/compile.py
+++ b/tianshou/utils/compile.py
@@ -1,12 +1,13 @@
import numpy as np
-# functions that need to pre-compile for producing benchmark result
from tianshou.policy.base import _episodic_return, _nstep_return
from tianshou.data.utils.segtree import _reduce, _setitem, _get_prefix_sum_idx
def pre_compile():
- """Since Numba acceleration needs to compile the function in the first run,
+ """Functions that need to pre-compile for producing benchmark result.
+
+ Since Numba acceleration needs to compile the function in the first run,
here we use some fake data for the common-type function-call compilation.
Otherwise, the current training speed cannot compare with the previous.
"""
diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py
index 7dfc0b1..a138b1c 100644
--- a/tianshou/utils/moving_average.py
+++ b/tianshou/utils/moving_average.py
@@ -6,8 +6,9 @@ from tianshou.data import to_numpy
class MovAvg(object):
- """Class for moving average. It will automatically exclude the infinity and
- NaN. Usage:
+ """Class for moving average.
+
+ It will automatically exclude infinity and NaN. Usage:
::
>>> stat = MovAvg(size=66)
@@ -30,8 +31,10 @@ class MovAvg(object):
self.banned = [np.inf, np.nan, -np.inf]
def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float:
- """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
- only one element, a python scalar, or a list of python scalar.
+ """Add a scalar into :class:`MovAvg`.
+
+ You can add a ``torch.Tensor`` with only one element, a python scalar,
+ or a list of python scalars.
"""
if isinstance(x, torch.Tensor):
x = to_numpy(x.flatten())
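A short usage sketch (assuming ``MovAvg`` is importable from ``tianshou.utils`` and exposes ``get()``)::

    import numpy as np
    import torch
    from tianshou.utils import MovAvg

    stat = MovAvg(size=100)
    stat.add(3.0)                    # python scalar
    stat.add(torch.tensor([4.0]))    # torch.Tensor with a single element
    stat.add([5.0, np.inf])          # inf / nan entries are excluded
    print(stat.get())                # running mean of 3.0, 4.0, 5.0 -> 4.0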
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index eb68a97..8c4fcc5 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -8,6 +8,7 @@ from tianshou.data import to_torch
def miniblock(inp: int, oup: int,
norm_layer: nn.modules.Module) -> List[nn.modules.Module]:
+ """Construct a miniblock with given input/output-size and norm layer."""
ret = [nn.Linear(inp, oup)]
if norm_layer is not None:
ret += [norm_layer(oup)]
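A usage sketch of ``miniblock`` (the trailing activation appended after the optional norm layer is assumed, since the hunk cuts off before it)::

    import torch
    import torch.nn as nn
    from tianshou.utils.net.common import miniblock

    layers = miniblock(4, 16, nn.LayerNorm)   # Linear(4, 16) -> LayerNorm(16) -> activation
    model = nn.Sequential(*layers)
    print(model(torch.zeros(2, 4)).shape)     # torch.Size([2, 16])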
@@ -16,8 +17,10 @@ def miniblock(inp: int, oup: int,
class Net(nn.Module):
- """Simple MLP backbone. For advanced usage (how to customize the network),
- please refer to :ref:`build_the_network`.
+ """Simple MLP backbone.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
:param bool concat: whether the input shape is concatenated by state_shape
and action_shape. If it is True, ``action_shape`` is not the output
@@ -76,7 +79,7 @@ class Net(nn.Module):
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):
- """s -> flatten -> logits"""
+ """Mapping: s -> flatten -> logits."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.reshape(s.size(0), -1)
logits = self.model(s)
@@ -89,8 +92,10 @@ class Net(nn.Module):
class Recurrent(nn.Module):
- """Simple Recurrent network based on LSTM. For advanced usage (how to
- customize the network), please refer to :ref:`build_the_network`.
+ """Simple Recurrent network based on LSTM.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
"""
def __init__(self, layer_num, state_shape, action_shape,
@@ -106,9 +111,11 @@ class Recurrent(nn.Module):
self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
def forward(self, s, state=None, info={}):
- """In the evaluation mode, s should be with shape ``[bsz, dim]``; in
- the training mode, s should be with shape ``[bsz, len, dim]``. See the
- code and comment for more detail.
+ """Mapping: s -> flatten -> logits.
+
+ In evaluation mode, s should have shape ``[bsz, dim]``; in training
+ mode, s should have shape ``[bsz, len, dim]``. See the code and
+ comments for more detail.
"""
s = to_torch(s, device=self.device, dtype=torch.float32)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
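For clarity, the two accepted input shapes can be illustrated with a pure-torch sketch; treating the evaluation input as a length-1 sequence is an assumption based on the comment above, not the full code::

    import torch

    bsz, length, dim = 4, 8, 3
    s_train = torch.zeros(bsz, length, dim)   # training: a sequence per sample
    s_eval = torch.zeros(bsz, dim)            # evaluation: a single frame per sample
    s_eval = s_eval.view(bsz, -1, dim)        # viewed as a length-1 sequence
    print(s_train.shape, s_eval.shape)        # [4, 8, 3] and [4, 1, 3]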
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 03a11f5..19586bb 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -6,7 +6,9 @@ from tianshou.data import to_torch, to_torch_as
class Actor(nn.Module):
- """For advanced usage (how to customize the network), please refer to
+ """Simple actor network with MLP.
+
+ For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
@@ -18,14 +20,16 @@ class Actor(nn.Module):
self._max = max_action
def forward(self, s, state=None, info={}):
- """s -> logits -> action"""
+ """Mapping: s -> logits -> action."""
logits, h = self.preprocess(s, state)
logits = self._max * torch.tanh(self.last(logits))
return logits, h
class Critic(nn.Module):
- """For advanced usage (how to customize the network), please refer to
+ """Simple critic network with MLP.
+
+ For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
@@ -36,7 +40,7 @@ class Critic(nn.Module):
self.last = nn.Linear(hidden_layer_size, 1)
def forward(self, s, a=None, info={}):
- """(s, a) -> logits -> Q(s, a)"""
+ """Mapping: (s, a) -> logits -> Q(s, a)."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
if a is not None:
@@ -49,7 +53,9 @@ class Critic(nn.Module):
class ActorProb(nn.Module):
- """For advanced usage (how to customize the network), please refer to
+ """Simple actor network (output with a Gauss distribution) with MLP.
+
+ For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
@@ -64,7 +70,7 @@ class ActorProb(nn.Module):
self._unbounded = unbounded
def forward(self, s, state=None, info={}):
- """s -> logits -> (mu, sigma)"""
+ """Mapping: s -> logits -> (mu, sigma)."""
logits, h = self.preprocess(s, state)
mu = self.mu(logits)
if not self._unbounded:
@@ -76,7 +82,9 @@ class ActorProb(nn.Module):
class RecurrentActorProb(nn.Module):
- """For advanced usage (how to customize the network), please refer to
+ """Recurrent version of ActorProb.
+
+ For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
@@ -121,7 +129,9 @@ class RecurrentActorProb(nn.Module):
class RecurrentCritic(nn.Module):
- """For advanced usage (how to customize the network), please refer to
+ """Recurrent version of Critic.
+
+ For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index c7fed2b..03f4583 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -5,7 +5,9 @@ import torch.nn.functional as F
class Actor(nn.Module):
- """For advanced usage (how to customize the network), please refer to
+ """Simple actor network with MLP.
+
+ For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
@@ -15,14 +17,16 @@ class Actor(nn.Module):
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
def forward(self, s, state=None, info={}):
- r"""s -> Q(s, \*)"""
+ r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state)
logits = F.softmax(self.last(logits), dim=-1)
return logits, h
class Critic(nn.Module):
- """For advanced usage (how to customize the network), please refer to
+ """Simple critic network with MLP.
+
+ For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
@@ -32,17 +36,17 @@ class Critic(nn.Module):
self.last = nn.Linear(hidden_layer_size, 1)
def forward(self, s, **kwargs):
- """s -> V(s)"""
+ """Mapping: s -> V(s)."""
logits, h = self.preprocess(s, state=kwargs.get('state', None))
logits = self.last(logits)
return logits
class DQN(nn.Module):
- """For advanced usage (how to customize the network), please refer to
- :ref:`build_the_network`.
+ """Reference: Human-level control through deep reinforcement learning.
- Reference paper: "Human-level control through deep reinforcement learning".
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
"""
def __init__(self, c, h, w, action_shape, device='cpu'):
@@ -78,7 +82,7 @@ class DQN(nn.Module):
)
def forward(self, x, state=None, info={}):
- r"""x -> Q(x, \*)"""
+ r"""Mapping: x -> Q(x, \*)."""
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32)
return self.net(x), state