fix #69
parent 1fce527c77
commit ba1b3e54eb
.github/workflows/pytest.yml (vendored): 9 changed lines
@@ -17,13 +17,14 @@ jobs:
       uses: actions/setup-python@v1
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install ".[dev]"
     - name: Lint with flake8
       run: |
+        python -m pip install --upgrade pip
+        pip install flake8
         flake8 . --count --show-source --statistics
+    - name: Install dependencies
+      run: |
+        pip install ".[dev]"
     - name: Test with pytest
       run: |
         pytest test --cov tianshou --cov-report=xml --durations 0 -v

CONTRIBUTING.rst (new symbolic link): 1 changed line
@@ -0,0 +1 @@
+docs/contributing.rst

docs/contributing.rst

@@ -1,14 +1,15 @@
+============
 Contributing
 ============

 Install Develop Version
------------------------
+=======================

 To install Tianshou in an "editable" mode, run

 .. code-block:: bash

-    pip3 install -e .
+    pip3 install -e ".[dev]"

 in the main directory. This installation is removable by

@@ -16,14 +17,8 @@ in the main directory. This installation is removable by

     python3 setup.py develop --uninstall

-Additional dependencies for developments can be installed by
-
-.. code-block:: bash
-
-    pip3 install ".[dev]"
-
 PEP8 Code Style Check
----------------------
+=====================

 We follow PEP8 python code style. To check, in the main directory, run:

@@ -32,7 +27,7 @@ We follow PEP8 python code style. To check, in the main directory, run:
     flake8 . --count --show-source --statistics

 Test Locally
-------------
+============

 This command will run automatic tests in the main directory

@@ -41,7 +36,7 @@ This command will run automatic tests in the main directory
     pytest test --cov tianshou -s --durations 0 -v

 Test by GitHub Actions
-----------------------
+======================

 1. Click the ``Actions`` button in your own repo:

@@ -61,7 +56,7 @@ Test by GitHub Actions
     :align: center

 Documentation
--------------
+=============

 Documentations are written under the ``docs/`` directory as ReStructuredText (``.rst``) files. ``index.rst`` is the main page. A Tutorial on ReStructuredText can be found `here <https://pythonhosted.org/an_example_pypi_project/sphinx.html>`_.

@@ -74,3 +69,8 @@ To compile documentation into webpages, run
     make html

 under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
+
+Chinese Documentation
+=====================
+
+Chinese documentation is in https://github.com/thu-ml/tianshou-docs-zh_CN.

docs/contributor.rst

@@ -1,3 +1,4 @@
+===========
 Contributor
 ===========

docs/index.rst

@@ -3,6 +3,7 @@
    You can adapt this file completely to your liking, but it should at least
    contain the root `toctree` directive.

+====================
 Welcome to Tianshou!
 ====================

@@ -25,9 +26,9 @@ Tianshou supports parallel workers for all algorithms as well. All of these algo


 Installation
-------------
+============

-Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
+Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command (with Python >= 3.6):
 ::

     pip3 install tianshou
@@ -86,7 +87,7 @@ Tianshou is still under development, you can also check out the documents in sta


 Indices and tables
-------------------
+==================

 * :ref:`genindex`
 * :ref:`modindex`

docs/tutorials/concepts.rst

@@ -1,3 +1,4 @@
+==========================
 Basic concepts in Tianshou
 ==========================

@@ -9,7 +10,7 @@ Tianshou splits a Reinforcement Learning agent training procedure into these par


 Data Batch
-----------
+==========

 .. automodule:: tianshou.data.Batch
    :members:
@@ -17,7 +18,7 @@ Data Batch


 Data Buffer
------------
+===========

 .. automodule:: tianshou.data.ReplayBuffer
    :members:
@@ -28,7 +29,7 @@ Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListR
 .. _policy_concept:

 Policy
-------
+======

 Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.

@@ -90,7 +91,7 @@ For other method, you can check out :doc:`/api/tianshou.policy`. We give the usa


 Collector
----------
+=========

 The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
 In short, :class:`~tianshou.data.Collector` has two main methods:
@@ -106,7 +107,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collecto


 Trainer
--------
+=======

 Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.

@@ -118,7 +119,7 @@ There will be more types of trainers, for instance, multi-agent trainer.
 .. _pseudocode:

 A High-level Explanation
-------------------------
+========================

 We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
 ::
@@ -141,6 +142,6 @@ We give a high-level explanation through the pseudocode used in section :ref:`po


 Conclusion
-----------
+==========

 So far, we go through the overall framework of Tianshou. Really simple, isn't it?

docs/tutorials/dqn.rst

@@ -1,3 +1,4 @@
+==============
 Deep Q Network
 ==============

@@ -10,7 +11,7 @@ Contrary to existing Deep RL libraries such as `RLlib <https://github.com/ray-pr


 Make an Environment
--------------------
+===================

 First of all, you have to make an environment for your agent to interact with. For environment interfaces, we follow the convention of `OpenAI Gym <https://github.com/openai/gym>`_. In your Python code, simply import Tianshou and make the environment:
 ::
@@ -24,7 +25,7 @@ CartPole-v0 is a simple environment with a discrete action space, for which DQN


 Setup Multi-environment Wrapper
--------------------------------
+===============================

 It is available if you want the original ``gym.Env``:
 ::
@@ -44,7 +45,7 @@ For the demonstration, here we use the second block of codes.


 Build the Network
------------------
+=================

 Tianshou supports any user-defined PyTorch networks and optimizers but with the limitation of input and output API. Here is an example code:
 ::
@@ -80,7 +81,7 @@ The rules of self-defined networks are:


 Setup Policy
-------------
+============

 We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with using a target network:
 ::
@@ -91,7 +92,7 @@ We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to


 Setup Collector
----------------
+===============

 The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently.
 In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.
@@ -102,7 +103,7 @@ In each step, the collector will let the policy perform (at least) a specified n


 Train Policy with a Trainer
----------------------------
+===========================

 Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows:
 ::
@@ -158,7 +159,7 @@ It shows that within approximately 4 seconds, we finished training a DQN agent o


 Save/Load Policy
-----------------
+================

 Since the policy inherits the ``torch.nn.Module`` class, saving and loading the policy are exactly the same as a torch module:
 ::
@@ -168,7 +169,7 @@ Since the policy inherits the ``torch.nn.Module`` class, saving and loading the


 Watch the Agent's Performance
------------------------------
+=============================

 :class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent's performance in 35 FPS:
 ::
@@ -181,7 +182,7 @@ Watch the Agent's Performance
 .. _customized_trainer:

 Train a Policy with Customized Codes
-------------------------------------
+====================================

 "I don't want to use your provided trainer. I want to customize it!"

docs/tutorials/trick.rst

@@ -1,3 +1,4 @@
+======================================
 Train a model-free RL agent within 30s
 ======================================

@@ -7,7 +8,7 @@ You can also contribute to this page with your own tricks :)


 Avoid batch-size = 1
---------------------
+====================

 In the traditional RL training loop, we always use the policy to interact with only one environment for collecting data. That means most of the time the network use batch-size = 1. Quite inefficient!
 Here is an example of showing how inefficient it is:
@@ -52,7 +53,7 @@ By the way, A2C is better than A3C in some cases: A3C needs to act independently


 Algorithm specific tricks
--------------------------
+=========================

 Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:

@@ -66,7 +67,7 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:


 Code-level optimization
------------------------
+=======================

 Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.

@@ -74,7 +75,7 @@ Tianshou has many short-but-efficient lines of code. For example, when we want t


 Finally
--------
+=======

 With fast-speed sampling, we could use large batch-size and large learning rate for faster convergence.

test/base/env.py

@@ -28,7 +28,7 @@ class MyTestEnv(gym.Env):
         if action == 0:
             self.index = max(self.index - 1, 0)
             if self.dict_state:
-                return {'index': self.index}, 0, False, {}
+                return {'index': self.index}, 0, False, {'key': 1, 'env': self}
             else:
                 return self.index, 0, False, {}
         elif action == 1:
@@ -36,6 +36,7 @@ class MyTestEnv(gym.Env):
             self.done = self.index == self.size
             if self.dict_state:
                 return {'index': self.index}, int(self.done), self.done, \
-                    {'key': 1}
+                    {'key': 1, 'env': self}
             else:
-                return self.index, int(self.done), self.done, {'key': 1}
+                return self.index, int(self.done), self.done, \
+                    {'key': 1, 'env': self}

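Context for the change above: the test environment's ``step()`` now also returns the environment instance itself inside ``info``, so downstream code (``Collector``/``ReplayBuffer``) has to cope with info values that are arbitrary objects rather than plain scalars. A minimal, self-contained sketch of that pattern (illustrative only; this is not the actual ``MyTestEnv`` code, and the class and argument names here are made up):

    import gym

    class ToyEnv(gym.Env):
        """Tiny corridor env whose step() also puts the env object into `info`."""

        def __init__(self, size=5):
            self.size = size
            self.index = 0

        def reset(self):
            self.index = 0
            return self.index

        def step(self, action):
            self.index = min(self.index + int(action), self.size)
            done = self.index == self.size
            # `info` carries a scalar and a non-array object (the env itself),
            # which is exactly the case the buffer changes below must handle.
            return self.index, int(done), done, {'key': 1, 'env': self}

    env = ToyEnv()
    obs, rew, done, info = env.step(1)
    assert info['env'] is env
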
test/base/test_batch.py

@@ -1,6 +1,6 @@
-import pytest
-import pickle
 import torch
+import pickle
+import pytest
 import numpy as np

 from tianshou.data import Batch, to_torch
@@ -15,11 +15,15 @@ def test_batch():
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
     assert batch[0].obs == batch[1].obs
-    with pytest.raises(IndexError):
-        batch[2]
     batch.obs = np.arange(5)
     for i, b in enumerate(batch.split(1, shuffle=False)):
-        assert b.obs == batch[i].obs
+        if i != 5:
+            assert b.obs == batch[i].obs
+        else:
+            with pytest.raises(AttributeError):
+                batch[i].obs
+            with pytest.raises(AttributeError):
+                b.obs
     print(batch)


@@ -95,3 +99,7 @@ def test_batch_from_to_numpy_without_copy():
 if __name__ == '__main__':
     test_batch()
     test_batch_over_batch()
+    test_batch_over_batch_to_torch()
+    test_utils_to_torch()
+    test_batch_pickle()
+    test_batch_from_to_numpy_without_copy()

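The ``if i != 5`` branch above exists because, after ``batch.obs = np.arange(5)``, the batch still holds the ``(6, 4)`` array under ``np``; with the revised ``Batch.__len__`` (see the ``tianshou/data/batch.py`` hunks below) the batch length is the maximum over its fields, so ``split(1, shuffle=False)`` yields six mini-batches and the sixth carries no ``obs``. A small sketch of the same behaviour via direct indexing (assumes tianshou at this revision is importable; field names are arbitrary):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(obs=np.arange(5), vec=np.zeros([6, 4]))
    assert len(b) == 6   # length is now the max over keys, not the min
    last = b[5]          # out-of-range keys are skipped instead of raising IndexError
    assert last.vec.shape == (4,)
    try:
        last.obs         # the missing key surfaces as an AttributeError
    except AttributeError:
        pass
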
test/base/test_env.py

@@ -39,7 +39,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
                A = v.reset(np.where(C)[0])
            o.append([A, B, C, D])
        for i in zip(*o):
-            for j in range(1, len(i)):
+            for j in range(1, len(i) - 1):
                assert (i[0] == i[j]).all()
    else:
        t = [0, 0, 0]

tianshou/data/batch.py

@@ -122,8 +122,11 @@ class Batch:
             return self.__getattr__(index)
         b = Batch()
         for k, v in self.__dict__.items():
-            if k != '_meta' and v is not None:
-                b.__dict__.update(**{k: v[index]})
+            if k != '_meta' and hasattr(v, '__len__'):
+                try:
+                    b.__dict__.update(**{k: v[index]})
+                except IndexError:
+                    continue
         b._meta = self._meta
         return b

@@ -238,8 +241,8 @@ class Batch:

     def __len__(self) -> int:
         """Return len(self)."""
-        return min([len(v) for k, v in self.__dict__.items()
-                    if k != '_meta' and v is not None])
+        r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]
+        return max(r) if len(r) > 0 else 0

     def split(self, size: Optional[int] = None,
               shuffle: bool = True) -> Iterator['Batch']:

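Taken together, the two hunks above make ``Batch`` tolerant of fields that are scalars (no ``__len__``) or shorter than the longest field: ``__len__`` reports the maximum length over sized fields, and ``__getitem__`` skips keys that cannot be indexed that far. A minimal sketch of the resulting behaviour (assumes tianshou at this revision; the field names are arbitrary):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(obs=np.zeros([4, 3]), turn=1)   # `turn` is a plain int with no __len__
    assert len(b) == 4                        # scalar fields are ignored by __len__
    row = b[2]                                # ... and by __getitem__
    assert row.obs.shape == (3,)
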
tianshou/data/buffer.py

@@ -1,7 +1,6 @@
 import pprint
 import numpy as np
-from copy import deepcopy
-from typing import Tuple, Union, Optional
+from typing import Any, Tuple, Union, Optional

 from tianshou.data.batch import Batch

@@ -60,7 +59,7 @@ class ReplayBuffer(object):
        ReplayBuffer(
            act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
            done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
-           info: array([{}, {}, {}, {}, {}, {}, {}, {}, {}], dtype=object),
+           info: Batch(),
            obs: Batch(
                id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
            ),
@@ -137,9 +136,7 @@ class ReplayBuffer(object):
             d[k_] = self.__dict__[k__]
         return Batch(**d)

-    def _add_to_buffer(
-            self, name: str,
-            inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
+    def _add_to_buffer(self, name: str, inst: Any) -> None:
         if inst is None:
             if getattr(self, name, None) is None:
                 self.__dict__[name] = None
@@ -153,18 +150,17 @@ class ReplayBuffer(object):
                 self.__dict__[name] = np.zeros(
                     (self._maxsize, *inst.shape), dtype=inst.dtype)
             elif isinstance(inst, (dict, Batch)):
-                if name == 'info':
-                    self.__dict__[name] = np.array(
-                        [{} for _ in range(self._maxsize)])
-                else:
-                    if self._meta.get(name, None) is None:
-                        self._meta[name] = list(inst.keys())
-                    for k in inst.keys():
-                        k_ = '_' + name + '@' + k
-                        self._add_to_buffer(k_, inst[k])
-            else:  # assume `inst` is a number
+                if self._meta.get(name, None) is None:
+                    self._meta[name] = list(inst.keys())
+                for k in inst.keys():
+                    k_ = '_' + name + '@' + k
+                    self._add_to_buffer(k_, inst[k])
+            elif np.isscalar(inst):
                 self.__dict__[name] = np.zeros(
                     (self._maxsize,), dtype=np.asarray(inst).dtype)
+            else:  # fall back to np.object
+                self.__dict__[name] = np.array(
+                    [None for _ in range(self._maxsize)])
             if isinstance(inst, np.ndarray) and \
                     self.__dict__[name].shape[1:] != inst.shape:
                 raise ValueError(
@@ -172,8 +168,6 @@ class ReplayBuffer(object):
                     f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
                     f"given shape: {inst.shape}.")
         if name not in self._meta:
-            if name == 'info':
-                inst = deepcopy(inst)
             self.__dict__[name][self._index] = inst

     def update(self, buffer: 'ReplayBuffer') -> None:
@@ -198,7 +192,7 @@ class ReplayBuffer(object):
             policy: Optional[Union[dict, Batch]] = {},
             **kwargs) -> None:
         """Add a batch of data into replay buffer."""
-        assert isinstance(info, dict), \
+        assert isinstance(info, (dict, Batch)), \
             'You should return a dict in the last argument of env.step().'
         self._add_to_buffer('obs', obs)
         self._add_to_buffer('act', act)
@@ -330,8 +324,6 @@ class ListReplayBuffer(ReplayBuffer):
             return
         if self.__dict__.get(name, None) is None:
             self.__dict__[name] = []
-        if name == 'info':
-            inst = deepcopy(inst)
         self.__dict__[name].append(inst)

     def reset(self) -> None:

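Rough usage sketch of what the buffer changes above enable: ``info`` may now be a ``dict`` or a ``Batch`` and is unpacked into per-key storage (the ``'_info@...'`` columns) instead of an object array of dicts, with values that have no clean dtype falling back to an object array. Illustrative only, assuming tianshou at this revision:

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=10)
    for i in range(3):
        buf.add(obs=i, act=0, rew=0.0, done=i == 2, obs_next=i + 1,
                info={'key': 1, 'env': object()})   # 'env' has no dtype -> object column

    print(buf)                 # repr shows `info: Batch(...)` rather than an array of dicts
    batch, indice = buf.sample(2)
    print(batch.info)          # per-key data reassembled from the '_info@...' columns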