add docs of collector and trainer (#20)
parent 4d4d0daf9e
commit 610390c132
@ -49,8 +49,8 @@ If no error occurs, you have successfully installed Tianshou.

   tutorials/dqn
   tutorials/concepts
   tutorials/tabular
   tutorials/trick

.. toctree::
   :maxdepth: 1
@ -23,7 +23,7 @@ Data Buffer

   :members:
   :noindex:

Tianshou provides other types of data buffers, such as :class:`~tianshou.data.ListReplayBuffer` (based on ``list``) and :class:`~tianshou.data.PrioritizedReplayBuffer` (based on a Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.


Policy
@ -85,7 +85,41 @@ Thus, we need a time-related interface for calculating the 2-step return. :meth:

This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.policy.BasePolicy.process_fn`.
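
As a standalone sketch (not the tutorial's code above), a 2-step return that does take the done flag into account can be computed with plain ``numpy``; the ``q_next2`` argument, holding the bootstrap value :math:`Q(s_{t+2})` for each sampled transition, is an assumed input here:
::

    import numpy as np

    def two_step_return(rew, done, q_next2, gamma=0.99):
        """Compute r_t + gamma * r_{t+1} + gamma^2 * Q(s_{t+2}), cutting the
        bootstrap off when an episode ends in between. All arguments are 1-D
        arrays of the same length; the wrap-around at the end of the arrays
        is ignored for simplicity."""
        not_done = 1.0 - done                  # 0 where the episode ended
        rew_next = np.roll(rew, -1)            # r_{t+1}
        not_done_next = np.roll(not_done, -1)
        return rew + gamma * not_done * (rew_next + gamma * not_done_next * q_next2)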

For other methods, you can check out :doc:`/api/tianshou.policy`. We give a high-level explanation of how the policy class is used in :ref:`pseudocode`.


Collector
---------

The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.

In short, :class:`~tianshou.data.Collector` has two main methods:

* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from the replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.

Why do we mention **at least** here? For a single environment, the collector finishes exactly ``n_step`` steps or ``n_episode`` episodes. For multiple environments, however, we cannot store the collected data directly into the replay buffer, since that would break the principle of storing data chronologically.

The solution is to add some cache buffers inside the collector. Once **a full episode of trajectory** has been collected, its data is moved from the cache buffer to the main buffer. To satisfy this condition, the collector may interact with the environments for more than the given number of steps or episodes.

The general explanation is given in :ref:`pseudocode`. Other usages of the collector are listed in the :class:`~tianshou.data.Collector` documentation and condensed in the sketch below.
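
The following is a condensed version of the usage example that this commit adds to the :class:`~tianshou.data.Collector` docstring (the ``PGPolicy(...)`` constructor arguments are elided there as well):
::

    policy = PGPolicy(...)  # or other policies if you wish
    env = gym.make('CartPole-v0')
    replay_buffer = ReplayBuffer(size=10000)

    # set up a collector with a single environment
    collector = Collector(policy, env, buffer=replay_buffer)

    # collect at least 3 episodes, then sample a mini-batch of 64 transitions
    collector.collect(n_episode=3)
    batch_data = collector.sample(batch_size=64)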


Trainer
-------

Once you have a collector and a policy, you can start writing the training procedure for your RL agent. The trainer is, to be honest, a simple wrapper: it saves you the effort of writing the training loop yourself. You can also construct your own trainer: see :ref:`customized_trainer`.

Tianshou has two types of trainers, :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage; a rough sketch of the loop they wrap is given below.

There will be more types of trainers in the future, for instance, a multi-agent trainer.
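
For intuition only, the loop that :func:`~tianshou.trainer.offpolicy_trainer` wraps looks roughly like the sketch below (logging, ``stop_fn`` and the ``train_fn``/``test_fn`` hooks are omitted, and the function name is purely illustrative):
::

    def offpolicy_loop(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
                       batch_size):
        """A stripped-down outline of what an off-policy trainer does."""
        for epoch in range(1, max_epoch + 1):
            for _ in range(step_per_epoch):
                # gather fresh experience, then do one network update
                train_collector.collect(n_step=collect_per_step)
                batch = train_collector.sample(batch_size)
                policy.learn(batch)
            # evaluate the current policy at the end of each epoch
            test_collector.collect(n_episode=episode_per_test)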


.. _pseudocode:

A High-level Explanation
------------------------

We give a high-level explanation through the pseudocode used in the Policy section:
::

    # pseudocode, cannot work                                  # methods in tianshou
@ -105,31 +139,6 @@ For other method, you can check out the API documentation for more detail. We gi

    agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)              # done in policy.learn(batch, ...)


Conclusion
----------

@ -122,7 +122,7 @@ The meaning of each parameter is as follows:

* ``max_epoch``: The maximum number of epochs for training. The training process might finish before reaching ``max_epoch``;
* ``step_per_epoch``: The number of steps for updating the policy network in one epoch;
* ``collect_per_step``: The number of frames the collector collects before each network update. For example, the code above means "collect 10 frames and do one policy network update";
* ``episode_per_test``: The number of episodes for one policy evaluation;
* ``batch_size``: The batch size of the sample data that is fed into the policy network;
* ``train_fn``: A function that receives the current epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training";
* ``test_fn``: A function that receives the current epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
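
Putting these together, the tutorial's call presumably resembles the sketch below; the concrete numbers and the ``policy.set_eps`` epsilon hooks are illustrative assumptions, not lines quoted from the tutorial:
::

    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=1000, collect_per_step=10,
        episode_per_test=100, batch_size=64,
        train_fn=lambda epoch: policy.set_eps(0.1),
        test_fn=lambda epoch: policy.set_eps(0.05))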
@ -3,8 +3,7 @@ import numpy as np


class Batch(object):
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
    structure to pass any kind of data to other methods, for example, a
    collector gives a :class:`~tianshou.data.Batch` to policy for learning.
    Here is the usage:
@ -25,12 +24,12 @@ class Batch(object):

    current implementation of Tianshou typically uses 6 keys in
    :class:`~tianshou.data.Batch`:

    * ``obs`` the observation of step :math:`t` ;
    * ``act`` the action of step :math:`t` ;
    * ``rew`` the reward of step :math:`t` ;
    * ``done`` the done flag of step :math:`t` ;
    * ``obs_next`` the observation of step :math:`t+1` ;
    * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
        function returns 4 arguments, and the last one is ``info``);
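
    For illustration only (the values below are made up, and the constructor
    is assumed to simply store the given keyword arguments as attributes)::

        >>> import numpy as np
        >>> data = Batch(obs=np.array([0, 1]), act=np.array([1, 0]),
        ...              rew=np.array([0., 1.]), done=np.array([0, 1]),
        ...              obs_next=np.array([1, 2]), info=np.array([{}, {}]))
        >>> data.rew
        array([0., 1.])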

    :class:`~tianshou.data.Batch` has other methods, including
@ -75,7 +74,7 @@ class Batch(object):
        return b

    def append(self, batch):
        """Append a :class:`~tianshou.data.Batch` object to the current batch."""
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for k in batch.__dict__.keys():
            if batch.__dict__[k] is None:
@ -103,12 +102,11 @@ class Batch(object):
            if self.__dict__[k] is not None])

    def split(self, size=None, permute=True):
        """Split the whole data into multiple small batches.

        :param size: if it is ``None``, do not split the data batch;
            otherwise divide the data batch with the given size.
        :param permute: randomly shuffle the entire data batch if it is
            ``True``, otherwise keep the original order.
        """
        length = len(self)

@ -3,17 +3,18 @@ from tianshou.data.batch import Batch


class ReplayBuffer(object):
    """:class:`~tianshou.data.ReplayBuffer` stores data generated from
    interaction between the policy and environment. It stores basically 6 types
    of data, as mentioned in :class:`~tianshou.data.Batch`, based on
    ``numpy.ndarray``. Here is the usage:
    ::

        >>> from tianshou.data import ReplayBuffer
        >>> buf = ReplayBuffer(size=20)
        >>> for i in range(3):
        ...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> len(buf)
        3
        >>> buf.obs
        # since we set size = 20, len(buf.obs) == 20.
        array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
@ -22,11 +23,13 @@ class ReplayBuffer(object):
        >>> buf2 = ReplayBuffer(size=10)
        >>> for i in range(15):
        ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> len(buf2)
        10
        >>> buf2.obs
        # since its size = 10, it only stores the last 10 steps' result.
        array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])

        >>> # move buf2's result into buf (while keeping the chronological order)
        >>> buf.update(buf2)
        array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
                0., 0., 0., 0., 0., 0., 0.])
@ -96,8 +99,8 @@ class ReplayBuffer(object):
        self.indice = []

    def sample(self, batch_size):
        """Get a random sample from the buffer with size equal to ``batch_size``.
        Return all the data in the buffer if ``batch_size`` is ``0``.

        :return: Sample data and its corresponding index inside the buffer.
        """
@ -123,9 +126,8 @@ class ReplayBuffer(object):


class ListReplayBuffer(ReplayBuffer):
    """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
    same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
    :class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
    """

@ -9,7 +9,51 @@ from tianshou.utils import MovAvg


class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently. Here is the usage:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffers to the collector for multi-env
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        # sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given batch size
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        # supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        # clear the buffer
        collector.reset_buffer()

    For the scenario of collecting data from multiple environments to a single
    buffer, the cache buffers will turn on automatically. The collector may
    therefore return more data than the given limitation.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(self, policy, env, buffer=None, stat_size=100):
        super().__init__()
@ -48,6 +92,7 @@ class Collector(object):
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        """Reset the main data buffer."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
@ -55,9 +100,13 @@ class Collector(object):
            self.buffer.reset()

    def get_env_num(self):
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self):
        """Reset all of the environments' states and reset all of the cache
        buffers (if needed).
        """
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
@ -69,14 +118,17 @@ class Collector(object):
                b.reset()

    def seed(self, seed=None):
        """Reset all the seed(s) of the given environment(s)."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        """Render all the environment(s)."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        """Close the environment(s)."""
        if hasattr(self.env, 'close'):
            self.env.close()
@ -87,12 +139,34 @@ class Collector(object):
        return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        """Collect a specified number of steps or episodes.

        :param n_step: an int, indicates how many steps you want to collect.
        :param n_episode: an int or a list, indicates how many episodes you
            want to collect (in each environment).
        :param render: a float, the sleep time between rendering consecutive
            frames. ``0`` means no rendering.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys:

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
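
        A small sketch of reading this dict (assuming the collector set up in
        the class usage example above)::

            result = collector.collect(n_step=100)
            print(result['n/st'], result['rew'])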
        """
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
@ -218,6 +292,14 @@ class Collector(object):
        }

    def sample(self, batch_size):
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param batch_size: an int, ``0`` means it will extract all the data
            from the buffer, otherwise it will extract the given batch_size of
            data.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]

tianshou/env/vecenv.py
@ -12,12 +12,12 @@ from tianshou.env.utils import CloudpickleWrapper


class BaseVectorEnv(ABC, gym.Wrapper):
    """Base class for vectorized environment wrappers. Usage:
    ::

        env_num = 8
        envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)])
        assert len(envs) == env_num

    It accepts a list of environment generators. In other words, an environment
    generator ``efn`` of a specific task means that ``efn()`` returns the
@ -46,8 +46,7 @@ class BaseVectorEnv(ABC, gym.Wrapper):

    @abstractmethod
    def reset(self, id=None):
        """Reset the state of all the environments and return the initial
        observations if id is ``None``, otherwise reset the specific
        environments with the given id, either an int or a list.
        """
@ -55,38 +54,38 @@ class BaseVectorEnv(ABC, gym.Wrapper):

    @abstractmethod
    def step(self, action):
        """Run one timestep of all the environments' dynamics. When the end of
        an episode is reached, you are responsible for calling reset(id) to
        reset this environment's state.

        Accept a batch of actions and return a tuple (obs, rew, done, info).

        :param action: a numpy.ndarray, a batch of actions provided by the
            agent.

        :return: A tuple including four items:

            * ``obs`` a numpy.ndarray, the agent's observation of the current \
                environments
            * ``rew`` a numpy.ndarray, the amount of rewards returned after \
                previous actions
            * ``done`` a numpy.ndarray, whether these episodes have ended, in \
                which case further step() calls will return undefined results
            * ``info`` a numpy.ndarray, contains auxiliary diagnostic \
                information (helpful for debugging, and sometimes learning)
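
        An illustrative call (a sketch, assuming three ``CartPole-v0`` workers
        as in the class usage above; the actions are arbitrary valid discrete
        actions)::

            import gym
            import numpy as np
            from tianshou.env import VectorEnv

            envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
            obs = envs.reset()                    # a batch of 3 observations
            obs, rew, done, info = envs.step(np.array([0, 1, 0]))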
        """
        pass

    @abstractmethod
    def seed(self, seed=None):
        """Set the seed for all environments. Accept ``None``, an int (which
        will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
        """
        pass

    @abstractmethod
    def render(self, **kwargs):
        """Render all of the environments."""
        pass

    @abstractmethod
@ -96,8 +95,7 @@ class BaseVectorEnv(ABC, gym.Wrapper):


class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in a for-loop. The
    usage is in :class:`~tianshou.env.BaseVectorEnv`.
    """

@ -173,8 +171,7 @@ def worker(parent, p, env_fn_wrapper):


class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess. The usage is in
    :class:`~tianshou.env.BaseVectorEnv`.
    """

@ -248,11 +245,11 @@ class SubprocVectorEnv(BaseVectorEnv):


class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. However, according to our
    test, it is about two times slower than
    :class:`~tianshou.env.SubprocVectorEnv`. The usage is in
    :class:`~tianshou.env.BaseVectorEnv`.
    """

    def __init__(self, env_fns):
@ -2,8 +2,7 @@ import numpy as np


class OUNoise(object):
    """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
    Usage:
    ::
@ -25,9 +24,8 @@ class OUNoise(object):
        self.reset()

    def __call__(self, size, mu=.1):
        """Generate new noise. Return a ``numpy.ndarray`` whose size is equal
        to ``size``.
        """
        if self.x is None or self.x.shape != size:
            self.x = 0
@ -36,7 +34,5 @@ class OUNoise(object):
        return self.x

    def reset(self):
        """Reset to the initial state."""
        self.x = None
@ -64,7 +64,7 @@ class DDPGPolicy(BasePolicy):

    def process_fn(self, batch, buffer, indice):
        if self._rew_norm:
            # estimate the normalization statistics from at most the first
            # 1000 stored rewards, to avoid scanning a very large buffer
            bfr = buffer.rew[:min(len(buffer), 1000)]
            mean, std = bfr.mean(), bfr.std()
            if std > self.__eps:
                batch.rew = (batch.rew - mean) / std
@ -8,7 +8,41 @@ from tianshou.trainer import test_episode, gather_info
def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, episode_per_test,
                      batch_size, train_fn=None, test_fn=None, stop_fn=None,
                      writer=None, log_interval=1, verbose=True, task='',
                      **kwargs):
    """A wrapper for the off-policy trainer procedure.

    Parameters

    * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy` \
        class.
    * **train_collector** – the collector used for training.
    * **test_collector** – the collector used for testing.
    * **max_epoch** – the maximum number of epochs for training. The training \
        process might finish before reaching ``max_epoch``.
    * **step_per_epoch** – the number of steps for updating the policy \
        network in one epoch.
    * **collect_per_step** – the number of frames the collector would \
        collect before the network update. In other words, collect some \
        frames and do one policy network update.
    * **episode_per_test** – the number of episodes for one policy \
        evaluation.
    * **batch_size** – the batch size of the sample data, which is going to \
        be fed into the policy network.
    * **train_fn** – a function that receives the current epoch index \
        and performs some operations at the beginning of training in this \
        epoch.
    * **test_fn** – a function that receives the current epoch index \
        and performs some operations at the beginning of testing in this \
        epoch.
    * **stop_fn** – a function that receives the average undiscounted return \
        of the test result and returns a boolean which indicates whether \
        the goal has been reached.
    * **writer** – a SummaryWriter provided from TensorBoard.
    * **log_interval** – an int indicating the log interval of the writer.
    * **verbose** – a boolean indicating whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
@ -9,7 +9,44 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     step_per_epoch, collect_per_step, repeat_per_collect,
                     episode_per_test, batch_size,
                     train_fn=None, test_fn=None, stop_fn=None,
                     writer=None, log_interval=1, verbose=True, task='',
                     **kwargs):
    """A wrapper for the on-policy trainer procedure.

    Parameters

    * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy` \
        class.
    * **train_collector** – the collector used for training.
    * **test_collector** – the collector used for testing.
    * **max_epoch** – the maximum number of epochs for training. The training \
        process might finish before reaching ``max_epoch``.
    * **step_per_epoch** – the number of steps for updating the policy \
        network in one epoch.
    * **collect_per_step** – the number of frames the collector would \
        collect before the network update. In other words, collect some \
        frames and do one policy network update.
    * **repeat_per_collect** – the number of repeat times for policy \
        learning; for example, setting it to 2 means the policy needs to \
        learn each given batch of data twice.
    * **episode_per_test** – the number of episodes for one policy \
        evaluation.
    * **batch_size** – the batch size of the sample data, which is going to \
        be fed into the policy network.
    * **train_fn** – a function that receives the current epoch index \
        and performs some operations at the beginning of training in this \
        epoch.
    * **test_fn** – a function that receives the current epoch index \
        and performs some operations at the beginning of testing in this \
        epoch.
    * **stop_fn** – a function that receives the average undiscounted return \
        of the test result and returns a boolean which indicates whether \
        the goal has been reached.
    * **writer** – a SummaryWriter provided from TensorBoard.
    * **log_interval** – an int indicating the log interval of the writer.
    * **verbose** – a boolean indicating whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
@ -18,23 +18,22 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):


def gather_info(start_time, train_c, test_c, best_reward):
    """A simple wrapper for gathering information from collectors.

    :return: A dictionary with the following keys:

        * ``train_step`` the total collected step of the training collector;
        * ``train_episode`` the total collected episode of the training collector;
        * ``train_time/collector`` the time for collecting frames in the \
            training collector;
        * ``train_time/model`` the time for training models;
        * ``train_speed`` the speed of training (frames per second);
        * ``test_step`` the total collected step of the test collector;
        * ``test_episode`` the total collected episode of the test collector;
        * ``test_time`` the time for testing;
        * ``test_speed`` the speed of testing (frames per second);
        * ``best_reward`` the best reward over the test results;
        * ``duration`` the total elapsed time.
    """
    duration = time.time() - start_time
    model_time = duration - train_c.collect_time - test_c.collect_time
@ -3,8 +3,7 @@ import numpy as np


class MovAvg(object):
    """Class for moving average. Usage:
    ::

        >>> stat = MovAvg(size=66)
@ -25,8 +24,7 @@ class MovAvg(object):
        self.cache = []

    def add(self, x):
        """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
        only one element, a python scalar, or a list of python scalars. It will
        exclude the infinity.
        """