add some docs

This commit is contained in:
Trinkle23897 2020-04-03 21:28:12 +08:00
parent 6cfa876591
commit 974ade8019
25 changed files with 296 additions and 151 deletions

View File

@ -29,11 +29,11 @@ Tianshou supports parallel workers for all algorithms as well. All of these algo
In Chinese, Tianshou means innate talent, not taught by others. Tianshou is a reinforcement learning platform. As we know, an RL agent does not learn from humans, so the name "Tianshou" means that there is no teacher to study with; instead, the agent learns by interacting with an environment.
"[Tianshou (天授)](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)" means "bestowed by heaven", and by extension an innate talent. Tianshou is a reinforcement learning platform; it does not learn from humans, so the name "Tianshou" means that there is no teacher to teach, and learning happens through continuous interaction with the environment.
"Tianshou (天授)" means "bestowed by heaven", and by extension an innate talent. Tianshou is a reinforcement learning platform, and reinforcement learning algorithms do not learn from humans, so the name "Tianshou" means that there is no teacher to teach; instead, the agent learns by continuously interacting with the environment.
## Installation
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). You can simply install Tianshou with the following command:
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
```bash
pip3 install tianshou
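# optional sanity check, not part of the original README: confirm that the package imports
python3 -c "import tianshou; print(tianshou.__version__)"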

View File

@ -58,7 +58,7 @@ master_doc = 'index'
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
autodoc_default_options = {'special-members': '__call__, __getitem__, __len__'}
# -- Options for HTML output -------------------------------------------------

View File

@ -1,4 +1,4 @@
Basic Concepts in Tianshou
Basic concepts in Tianshou
==========================
Tianshou splits the training procedure of a Reinforcement Learning agent into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as:
@ -11,85 +11,19 @@ Tianshou splits a Reinforcement Learning agent training procedure into these par
Data Batch
----------
Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a collector gives a :class:`~tianshou.data.Batch` to policy for learning. Here is its usage:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> data.b
[5, 5]
>>> data.b = np.array([3, 4, 5])
>>> len(data.b)
3
>>> data.b[-1]
5
In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair.
The current implementation of Tianshou typically uses 6 keys in :class:`~tianshou.data.Batch`:
* ``obs``: the observation of step :math:`t` ;
* ``act``: the action of step :math:`t` ;
* ``rew``: the reward of step :math:`t` ;
* ``done``: the done flag of step :math:`t` ;
* ``obs_next``: the observation of step :math:`t+1` ;
* ``info``: the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 values, and the last one is ``info``);
:class:`~tianshou.data.Batch` has other methods, including ``__getitem__``, ``append``, and ``split``:
::
>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
>>> # here we test __getitem__
>>> index = [2, 1]
>>> data[index].obs
array([22, 11])
>>> data.append(data) # similar to list.append
>>> data.obs
array([0, 11, 22, 0, 11, 22])
>>> # split whole data into multiple small batch
>>> for d in data.split(size=2, permute=False):
... print(d.obs, d.rew)
[ 0 11] [6 6]
[22 0] [6 6]
[11 22] [6 6]
.. automodule:: tianshou.data.Batch
:members:
:noindex:
Data Buffer
-----------
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. It stores basically 6 types of data as mentioned above (7 types with importance weight in :class:`~tianshou.data.PrioritizedReplayBuffer`). Here is the :class:`~tianshou.data.ReplayBuffer`'s usage:
::
.. automodule:: tianshou.data.ReplayBuffer
:members:
:noindex:
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.])
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (keeping the chronological order meanwhile)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from the buffer; the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
The :class:`~tianshou.data.ReplayBuffer` is based on ``numpy.ndarray``. Tianshou provides other types of data buffers such as :class:`~tianshou.data.ListReplayBuffer` (based on list) and :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out the API documentation for more detail.
Tianshou provides other types of data buffers such as :class:`~tianshou.data.ListReplayBuffer` (based on list) and :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out the :doc:`/api/tianshou.data` API documentation for more detail.
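For instance, a list-based buffer can be constructed in the same way (a short sketch, not taken from the original documentation):
::
>>> from tianshou.data import ListReplayBuffer
>>> buf = ListReplayBuffer()  # grows dynamically instead of pre-allocating a fixed size
>>> buf.add(obs=0, act=0, rew=0, done=0, obs_next=1, info={})
>>> len(buf)
1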
Policy
@ -110,12 +44,12 @@ Take 2-step return DQN as an example. The 2-step return DQN compute each frame's
G_t = r_t + \gamma r_{t + 1} + \gamma^2 \max_a Q(s_{t + 2}, a)
Here is the pseudocode showing the training process **without Tianshou framework**:
where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is the pseudocode showing the training process **without Tianshou framework**:
::
# pseudocode, cannot work
buffer = Buffer(size=10000)
s = env.reset()
buffer = Buffer(size=10000)
agent = DQN()
for i in range(int(1e6)):
a = agent.compute_action(s)
@ -155,8 +89,8 @@ For other method, you can check out the API documentation for more detail. We gi
::
# pseudocode, cannot work # methods in tianshou
buffer = Buffer(size=10000)
s = env.reset()
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # done in policy.__init__(...)
for i in range(int(1e6)): # done in trainer
a = agent.compute_action(s) # done in policy.__call__(batch, ...)
@ -174,7 +108,7 @@ For other method, you can check out the API documentation for more detail. We gi
Collector
---------
The collector enables the policy to interact with different types of environments conveniently.
The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
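For illustration only, here is a minimal sketch that wires these pieces together. It mirrors the structure of the test scripts in this repository, but the environment name, network size, and hyper-parameters below are arbitrary placeholders rather than recommendations.
::
import gym
import numpy as np
import torch
from torch import nn
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy

class Net(nn.Module):
    # a tiny Q-network, similar to the ones used in the test scripts
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64), nn.ReLU(),
            nn.Linear(64, action_dim))

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float)  # observations arrive as numpy arrays
        return self.model(obs), state

env = gym.make('CartPole-v0')
net = Net(int(np.prod(env.observation_space.shape)), env.action_space.n)
policy = DQNPolicy(net, torch.optim.Adam(net.parameters(), lr=1e-3))
train_envs = VectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(4)])
collector = Collector(policy, train_envs, ReplayBuffer(size=20000))
collector.collect(n_step=1000)           # interact for at least 1000 steps
batch = collector.sample(batch_size=64)  # process_fn is applied before returning
collector.close()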

View File

@ -3,6 +3,8 @@ Train a model-free RL agent within 30s
This page summarizes some hyper-parameter tuning experience and code-level tricks used when training a model-free DRL agent.
You can also contribute to this page with your own tricks :)
Avoid batch-size = 1
--------------------
@ -44,7 +46,7 @@ Here is an example of showing how inefficient it is:
The first test uses batch-size = 128, and the second runs batch-size = 1 for 128 times. In our test, the first is 70-80 times faster than the second.
So how could we avoid the case of batch-size = 1? The answer is synchronized sampling: we create multiple independent environments and sample from them simultaneously. It is similar to A2C, but other algorithms can also use this method. In our experiments, sampling from more environments benefits not only the sampling speed but also the convergence of the neural network (we guess it lowers the sample bias).
So how could we avoid the case of batch-size = 1? The answer is synchronized sampling: we create multiple independent environments and sample from them simultaneously. It is similar to A2C, but other algorithms can also use this method. In our experiments, sampling from more environments benefits not only the sampling speed but also the convergence speed of the neural network (we guess it lowers the sample bias).
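For a rough picture of what synchronized sampling looks like in code (the environment and the number of workers below are arbitrary choices, not taken from this page):
::
import gym
from tianshou.env import VectorEnv

# 8 synchronized environments: every forward pass now sees a batch of 8 observations
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
obs = envs.reset()                         # stacked observations, shape (8, 4)
obs, rew, done, info = envs.step([0] * 8)  # one action per environment, one batched step
envs.close()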
By the way, A2C is better than A3C in some cases: A3C needs to act independently and sync the gradient to master, but, in a single node, using A3C to act with batch-size = 1 is quite resource-consuming.
@ -74,6 +76,9 @@ Jiayi: I write each line of code after quite a lot of time of consideration. Det
Finally
-------
With fast sampling, we could use a large batch-size and a large learning rate
With fast sampling, we could use a large batch-size and a large learning rate for faster convergence.
RL algorithms are seed-sensitive. Try more seeds and pick the best. But for our demo, we just used seed = 0 and found it works surprisingly well on policy gradient, so we did not try other seeds.
.. image:: ../_static/images/testpg.gif
:align: center

View File

@ -6,7 +6,7 @@ from tianshou.data import Batch
def test_batch():
batch = Batch(obs=[0], np=np.zeros([3, 4]))
batch.update(obs=[1])
batch.obs = [1]
assert batch.obs == [1]
batch.append(batch)
assert batch.obs == [1, 1]

View File

@ -5,10 +5,10 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import VectorEnv
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from net import Actor, Critic
@ -49,11 +49,12 @@ def test_ddpg(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@ -89,7 +90,7 @@ def test_ddpg(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PPOPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.env import VectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@ -53,11 +53,12 @@ def _test_ppo(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@ -97,8 +98,7 @@ def _test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -5,10 +5,10 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import VectorEnv
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from net import ActorProb, Critic
@ -49,11 +49,12 @@ def test_sac(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@ -94,7 +95,7 @@ def test_sac(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -5,10 +5,10 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import VectorEnv
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from net import Actor, Critic
@ -52,11 +52,12 @@ def test_td3(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@ -98,7 +99,7 @@ def test_td3(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.env import VectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@ -45,14 +45,16 @@ def get_args():
def test_a2c(args=get_args()):
torch.set_num_threads(1) # for poor CPU
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@ -83,8 +85,7 @@ def test_a2c(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -5,8 +5,8 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@ -48,6 +48,7 @@ def test_dqn(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
@ -89,7 +90,7 @@ def test_dqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, writer=writer, task=args.task)
stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()

View File

@ -6,8 +6,8 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PGPolicy
from tianshou.env import VectorEnv
from tianshou.policy import PGPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Batch, Collector, ReplayBuffer
@ -25,7 +25,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
if not batch.done[i]:
returns[i] += last * gamma
last = returns[i]
batch.update(returns=returns)
batch.returns = returns
return batch
@ -99,6 +99,7 @@ def test_pg(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
@ -131,8 +132,7 @@ def test_pg(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -5,8 +5,8 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import VectorEnv
from tianshou.policy import PPOPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@ -46,14 +46,16 @@ def get_args():
def test_ppo(args=get_args()):
torch.set_num_threads(1) # for poor CPU
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
# you can also use tianshou.env.SubprocVectorEnv
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@ -88,8 +90,7 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -3,23 +3,83 @@ import numpy as np
class Batch(object):
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
"""
Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
Here is the usage:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> data.b
[5, 5]
>>> data.b = np.array([3, 4, 5])
>>> len(data.b)
3
>>> data.b[-1]
5
In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically uses 6 keys in
:class:`~tianshou.data.Batch` (a short example follows the list below):
* ``obs``: the observation of step :math:`t` ;
* ``act``: the action of step :math:`t` ;
* ``rew``: the reward of step :math:`t` ;
* ``done``: the done flag of step :math:`t` ;
* ``obs_next``: the observation of step :math:`t+1` ;
* ``info``: the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function returns 4 values, and the last one is ``info``);
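For instance, the data of a single environment step could be stored as
(a hypothetical transition; the values below are arbitrary):
::
>>> transition = Batch(obs=np.array([0.0, 1.0]), act=1, rew=1.0,
...                    done=False, obs_next=np.array([1.0, 2.0]), info={})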
:class:`~tianshou.data.Batch` has other methods, including
:meth:`~tianshou.data.Batch.__getitem__`,
:meth:`~tianshou.data.Batch.__len__`,
:meth:`~tianshou.data.Batch.append`,
and :meth:`~tianshou.data.Batch.split`:
::
>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
>>> # here we test __getitem__
>>> index = [2, 1]
>>> data[index].obs
array([22, 11])
>>> # here we test __len__
>>> len(data)
3
>>> data.append(data) # similar to list.append
>>> data.obs
array([0, 11, 22, 0, 11, 22])
>>> # split whole data into multiple small batch
>>> for d in data.split(size=2, permute=False):
... print(d.obs, d.rew)
[ 0 11] [6 6]
[22 0] [6 6]
[11 22] [6 6]
"""
def __init__(self, **kwargs):
super().__init__()
self.__dict__.update(kwargs)
def __getitem__(self, index):
"""
Return self[index].
"""
b = Batch()
for k in self.__dict__.keys():
if self.__dict__[k] is not None:
b.update(**{k: self.__dict__[k][index]})
b.__dict__.update(**{k: self.__dict__[k][index]})
return b
def update(self, **kwargs):
self.__dict__.update(kwargs)
def append(self, batch):
"""
Append a :class:`~tianshou.data.Batch` object to the end.
"""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__.keys():
if batch.__dict__[k] is None:
@ -40,10 +100,24 @@ class Batch(object):
+ 'in class Batch.'
raise TypeError(s)
def split(self, size=None, permute=True):
length = min([
def __len__(self):
"""
Return len(self).
"""
return min([
len(self.__dict__[k]) for k in self.__dict__.keys()
if self.__dict__[k] is not None])
def split(self, size=None, permute=True):
"""
Split whole data into multiple small batch.
:param size: if it is ``None``, it does not split the data batch; \
otherwise it will divide the data batch with the given size.
:param permute: randomly shuffle the entire data batch if it is \
``True``, otherwise keep the original order.
"""
length = len(self)
if size is None:
size = length
temp = 0

View File

@ -3,7 +3,40 @@ from tianshou.data.batch import Batch
class ReplayBuffer(object):
"""docstring for ReplayBuffer"""
"""
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction
between the policy and environment. It stores basically 6 types of data, as
mentioned in :class:`~tianshou.data.Batch`, based on ``numpy.ndarray``.
Here is the usage:
::
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (keeping the chronological order meanwhile)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from the buffer;
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
"""
def __init__(self, size):
super().__init__()
@ -15,6 +48,9 @@ class ReplayBuffer(object):
del self.__dict__[k]
def __len__(self):
"""
Return len(self).
"""
return self._size
def _add_to_buffer(self, name, inst):
@ -34,6 +70,9 @@ class ReplayBuffer(object):
self.__dict__[name][self._index] = inst
def update(self, buffer):
"""
Move the data from the given buffer to self.
"""
i = begin = buffer._index % len(buffer)
while True:
self.add(
@ -45,7 +84,7 @@ class ReplayBuffer(object):
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
'''
weight: importance weights, disabled here
Add a batch of data into the replay buffer.
'''
assert isinstance(info, dict), \
'You should return a dict in the last argument of env.step().'
@ -62,10 +101,18 @@ class ReplayBuffer(object):
self._size = self._index = self._index + 1
def reset(self):
"""
Clear all the data in the replay buffer.
"""
self._index = self._size = 0
self.indice = []
def sample(self, batch_size):
"""
Get a random sample from the buffer with size = ``batch_size``.
:return: Sample data and its corresponding index inside the buffer.
"""
if batch_size > 0:
indice = np.random.choice(self._size, batch_size)
else:
@ -76,6 +123,9 @@ class ReplayBuffer(object):
return self[indice], indice
def __getitem__(self, index):
"""
Return a data batch: self[index].
"""
return Batch(
obs=self.obs[index],
act=self.act[index],
@ -87,7 +137,12 @@ class ReplayBuffer(object):
class ListReplayBuffer(ReplayBuffer):
"""docstring for ListReplayBuffer"""
"""
The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same
as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
"""
def __init__(self):
super().__init__(size=0)
@ -111,7 +166,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
def __init__(self, size):
super().__init__(size)
def add(self, obs, act, rew, done, obs_next, info={}, weight=None):
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
raise NotImplementedError
def sample(self, batch_size):

View File

@ -2,7 +2,21 @@ import numpy as np
class OUNoise(object):
"""docstring for OUNoise"""
"""
Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
Usage:
::
# init
self.noise = OUNoise()
# generate noise
noise = self.noise(logits.shape, eps)
For the required parameters, you can refer to the Stack Overflow page. However,
our experimental results show that (similar to OpenAI Spinning Up) using a
vanilla Gaussian process makes little difference compared with the
Ornstein-Uhlenbeck process.
"""
def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None):
self.alpha = theta * dt
@ -11,6 +25,10 @@ class OUNoise(object):
self.reset()
def __call__(self, size, mu=.1):
"""
Generate new noise. Return a ``numpy.ndarray`` whose size is equal to
``size``.
"""
if self.x is None or self.x.shape != size:
self.x = 0
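# one Euler step of the OU process dx = theta * (mu - x) * dt + sigma * dW;
# alpha = theta * dt comes from __init__ (beta is presumably sigma * sqrt(dt))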
r = self.beta * np.random.normal(size=size)
@ -18,4 +36,7 @@ class OUNoise(object):
return self.x
def reset(self):
"""
Reset to the initial state.
"""
self.x = None

View File

@ -46,10 +46,10 @@ class A2CPolicy(PGPolicy):
nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=self._grad_norm)
self.optim.step()
actor_losses.append(a_loss.detach().cpu().numpy())
vf_losses.append(vf_loss.detach().cpu().numpy())
ent_losses.append(ent_loss.detach().cpu().numpy())
losses.append(loss.detach().cpu().numpy())
actor_losses.append(a_loss.item())
vf_losses.append(vf_loss.item())
ent_losses.append(ent_loss.item())
losses.append(loss.item())
return {
'loss': losses,
'loss/actor': actor_losses,

View File

@ -107,6 +107,6 @@ class DDPGPolicy(BasePolicy):
self.actor_optim.step()
self.sync_weight()
return {
'loss/actor': actor_loss.detach().cpu().numpy(),
'loss/critic': critic_loss.detach().cpu().numpy(),
'loss/actor': actor_loss.item(),
'loss/critic': critic_loss.item(),
}

View File

@ -17,7 +17,7 @@ class DQNPolicy(BasePolicy):
self.model = model
self.optim = optim
self.eps = 0
assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
self._gamma = discount_factor
assert estimation_step > 0, 'estimation_step should greater than 0'
self._n_step = estimation_step
@ -66,7 +66,7 @@ class DQNPolicy(BasePolicy):
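# n-step target: `returns` (computed above) holds the discounted sum of the first n rewards;
# add gamma^n * max_a Q(s_{t+n}, a) and zero this bootstrap term for transitions whose
# episode terminated inside the n-step window (i.e. gammas != self._n_step)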
target_q = target_q.max(axis=1)
target_q[gammas != self._n_step] = 0
returns += (self._gamma ** gammas) * target_q
batch.update(returns=returns)
batch.returns = returns
return batch
def __call__(self, batch, state=None,
@ -96,4 +96,4 @@ class DQNPolicy(BasePolicy):
loss.backward()
self.optim.step()
self._cnt += 1
return {'loss': loss.detach().cpu().numpy()}
return {'loss': loss.item()}

View File

@ -20,9 +20,8 @@ class PGPolicy(BasePolicy):
self._gamma = discount_factor
def process_fn(self, batch, buffer, indice):
returns = self._vanilla_returns(batch)
# returns = self._vectorized_returns(batch)
batch.update(returns=returns)
batch.returns = self._vanilla_returns(batch)
# batch.returns = self._vectorized_returns(batch)
return batch
def __call__(self, batch, state=None):
@ -45,7 +44,7 @@ class PGPolicy(BasePolicy):
loss = -(dist.log_prob(a) * r).sum()
loss.backward()
self.optim.step()
losses.append(loss.detach().cpu().numpy())
losses.append(loss.item())
return {'loss': losses}
def _vanilla_returns(self, batch):

View File

@ -76,19 +76,19 @@ class PPOPolicy(PGPolicy):
surr2 = ratio.clamp(
1. - self._eps_clip, 1. + self._eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()
clip_losses.append(clip_loss.detach().cpu().numpy())
clip_losses.append(clip_loss.item())
vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
vf_losses.append(vf_loss.detach().cpu().numpy())
vf_losses.append(vf_loss.item())
e_loss = dist.entropy().mean()
ent_losses.append(e_loss.detach().cpu().numpy())
ent_losses.append(e_loss.item())
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
losses.append(loss.detach().cpu().numpy())
losses.append(loss.item())
self.optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(list(
self.actor.parameters()) + list(self.critic.parameters()),
self._max_grad_norm)
self._max_grad_norm)
self.optim.step()
self.sync_weight()
return {

View File

@ -99,7 +99,7 @@ class SACPolicy(DDPGPolicy):
self.actor_optim.step()
self.sync_weight()
return {
'loss/actor': actor_loss.detach().cpu().numpy(),
'loss/critic1': critic1_loss.detach().cpu().numpy(),
'loss/critic2': critic2_loss.detach().cpu().numpy(),
'loss/actor': actor_loss.item(),
'loss/critic1': critic1_loss.item(),
'loss/critic2': critic2_loss.item(),
}

View File

@ -81,12 +81,12 @@ class TD3Policy(DDPGPolicy):
batch.obs, self(batch, eps=0).act).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self._last = actor_loss.detach().cpu().numpy()
self._last = actor_loss.item()
self.actor_optim.step()
self.sync_weight()
self._cnt += 1
return {
'loss/actor': self._last,
'loss/critic1': critic1_loss.detach().cpu().numpy(),
'loss/critic2': critic2_loss.detach().cpu().numpy(),
'loss/critic1': critic1_loss.item(),
'loss/critic2': critic2_loss.item(),
}

View File

@ -3,6 +3,9 @@ import numpy as np
def test_episode(policy, collector, test_fn, epoch, n_episode):
"""
A simple wrapper for testing the policy in the collector.
"""
collector.reset_env()
collector.reset_buffer()
policy.eval()
@ -17,6 +20,24 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
def gather_info(start_time, train_c, test_c, best_reward):
"""
A simple wrapper for gathering information from collectors.
:return: A dictionary with the following keys:
* ``train_step``: the total collected step of training collector;
* ``train_episode``: the total collected episode of training collector;
* ``train_time/collector``: the time for collecting frames in the\
training collector;
* ``train_time/model``: the time for training models;
* ``train_speed``: the speed of training (frames per second);
* ``test_step``: the total collected step of test collector;
* ``test_episode``: the total collected episode of test collector;
* ``test_time``: the time for testing;
* ``test_speed``: the speed of testing (frames per second);
* ``best_reward``: the best reward over the test results;
* ``duration``: the total elapsed time.
"""
duration = time.time() - start_time
model_time = duration - train_c.collect_time - test_c.collect_time
train_speed = train_c.collect_step / (duration - test_c.collect_time)

View File

@ -3,14 +3,35 @@ import numpy as np
class MovAvg(object):
"""
Class for moving average. Usage:
::
>>> stat = MovAvg(size=66)
>>> stat.add(torch.tensor(5))
5.0
>>> stat.add(float('inf')) # which will not add to stat
5.0
>>> stat.add([6, 7, 8])
6.5
>>> stat.get()
6.5
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
6.50±1.12
"""
def __init__(self, size=100):
super().__init__()
self.size = size
self.cache = []
def add(self, x):
"""
Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
only one element, a python scalar, or a list of python scalars. Infinite
values will be excluded.
"""
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
x = x.item()
if isinstance(x, list):
for _ in x:
if _ != np.inf:
@ -22,14 +43,23 @@ class MovAvg(object):
return self.get()
def get(self):
"""
Get the average.
"""
if len(self.cache) == 0:
return 0
return np.mean(self.cache)
def mean(self):
"""
Get the average. Same as :meth:`get`.
"""
return self.get()
def std(self):
"""
Get the standard deviation.
"""
if len(self.cache) == 0:
return 0
return np.std(self.cache)