add docs of collector and trainer (#20)

This commit is contained in:
Trinkle23897 2020-04-05 18:34:45 +08:00
parent 4d4d0daf9e
commit 610390c132
13 changed files with 267 additions and 115 deletions

View File

@ -49,8 +49,8 @@ If no error occurs, you have successfully installed Tianshou.
tutorials/dqn
tutorials/concepts
tutorials/trick
tutorials/tabular
tutorials/trick
.. toctree::
:maxdepth: 1

View File

@ -23,7 +23,7 @@ Data Buffer
:members:
:noindex:
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out the :doc:`/api/tianshou.data` API documentation for more detail.
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
Policy
@ -85,7 +85,41 @@ Thus, we need a time-related interface for calculating the 2-step return. :meth:
This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.policy.BasePolicy.process_fn`.
For other method, you can check out the API documentation for more detail. We give a high-level explanation through the same pseudocode:
For other method, you can check out :doc:`/api/tianshou.policy`. We give the usage of policy class a high-level explanation in :ref:`pseudocode`.
Collector
---------
The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
In short, :class:`~tianshou.data.Collector` has two main methods:
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
Why do we mention **at least** here? For a single environment, the collector will finish exactly ``n_step`` or ``n_episode``. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation.
Trainer
-------
Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.
Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.
There will be more types of trainers, for instance, multi-agent trainer.
.. _pseudocode:
A High-level Explanation
------------------------
We give a high-level explanation through the pseudocode used in section Policy:
::
# pseudocode, cannot work # methods in tianshou
@ -105,31 +139,6 @@ For other method, you can check out the API documentation for more detail. We gi
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # done in policy.learn(batch, ...)
Collector
---------
The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
Why do we mention **at least** here? For a single environment, the collector will finish exactly ``n_step`` or ``n_episode``. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
The general explanation is listed in the pseudocode above. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation.
Trainer
-------
Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.
Tianshou has two types of trainer: :meth:`~tianshou.trainer.onpolicy_trainer` and :meth:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out the API documentation for the usage.
There will be more types of trainers, for instance, multi-agent trainer.
Conclusion
----------

View File

@ -122,7 +122,7 @@ The meaning of each parameter is as follows:
* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
* ``step_per_epoch``: The number of step for updating policy network in one epoch;
* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
* ``episode_per_test``: The number of episode for one policy evaluation.
* ``episode_per_test``: The number of episodes for one policy evaluation.
* ``batch_size``: The batch size of sample data, which is going to feed in the policy network.
* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".

View File

@ -3,8 +3,7 @@ import numpy as np
class Batch(object):
"""
Tianshou provides :class:`~tianshou.data.Batch` as the internal data
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
Here is the usage:
@ -25,12 +24,12 @@ class Batch(object):
current implementation of Tianshou typically use 6 keys in
:class:`~tianshou.data.Batch`:
* ``obs``: the observation of step :math:`t` ;
* ``act``: the action of step :math:`t` ;
* ``rew``: the reward of step :math:`t` ;
* ``done``: the done flag of step :math:`t` ;
* ``obs_next``: the observation of step :math:`t+1` ;
* ``info``: the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function return 4 arguments, and the last one is ``info``);
:class:`~tianshou.data.Batch` has other methods, including
@ -75,7 +74,7 @@ class Batch(object):
return b
def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to the end."""
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__.keys():
if batch.__dict__[k] is None:
@ -103,12 +102,11 @@ class Batch(object):
if self.__dict__[k] is not None])
def split(self, size=None, permute=True):
"""
Split whole data into multiple small batch.
"""Split whole data into multiple small batch.
:param size: if equals to ``None``, it does not split the data batch; \
:param size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size.
:param permute: randomly shuffle the entire data batch if it equals to\
:param permute: randomly shuffle the entire data batch if it is
``True``, otherwise remain in the same.
"""
length = len(self)

View File

@ -3,17 +3,18 @@ from tianshou.data.batch import Batch
class ReplayBuffer(object):
"""
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction
between the policy and environment. It stores basically 6 types of data, as
mentioned in :class:`~tianshou.data.Batch`, based on ``numpy.ndarray``.
Here is the usage:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from
interaction between the policy and environment. It stores basically 6 types
of data, as mentioned in :class:`~tianshou.data.Batch`, based on
``numpy.ndarray``. Here is the usage:
::
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
@ -22,11 +23,13 @@ class ReplayBuffer(object):
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (keep it chronologically meanwhile)
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
@ -96,8 +99,8 @@ class ReplayBuffer(object):
self.indice = []
def sample(self, batch_size):
"""
Get a random sample from buffer with size = ``batch_size``
"""Get a random sample from buffer with size equal to batch_size. \
Return all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
"""
@ -123,9 +126,8 @@ class ReplayBuffer(object):
class ListReplayBuffer(ReplayBuffer):
"""
The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same
as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
"""

View File

@ -9,7 +9,51 @@ from tianshou.utils import MovAvg
class Collector(object):
"""docstring for Collector"""
"""The :class:`~tianshou.data.Collector` enables the policy to interact
with different types of environments conveniently. Here is the usage:
::
policy = PGPolicy(...) # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)
# the collector supports vectorized environments as well
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
buffers = [ReplayBuffer(size=5000) for _ in range(3)]
# you can also pass a list of replay buffer to collector, for multi-env
# collector = Collector(policy, envs, buffer=buffers)
collector = Collector(policy, envs, buffer=replay_buffer)
# collect at least 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
# sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)
# sample data with a given number of batch-size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data) # btw, vanilla policy gradient only
# supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
# clear the buffer
collector.reset_buffer()
For the scenario of collecting data from multiple environments to a single
buffer, the cache buffers will turn on automatically. It may return the
data more than the given limitation.
.. note::
Please make sure the given environment has a time limitation.
"""
def __init__(self, policy, env, buffer=None, stat_size=100):
super().__init__()
@ -48,6 +92,7 @@ class Collector(object):
self.episode_speed = MovAvg(stat_size)
def reset_buffer(self):
"""Reset the main data buffer."""
if self._multi_buf:
for b in self.buffer:
b.reset()
@ -55,9 +100,13 @@ class Collector(object):
self.buffer.reset()
def get_env_num(self):
"""Return the number of environments the collector has."""
return self.env_num
def reset_env(self):
"""Reset all of the environment(s)' states and reset all of the cache
buffers (if need).
"""
self._obs = self.env.reset()
self._act = self._rew = self._done = self._info = None
if self._multi_env:
@ -69,14 +118,17 @@ class Collector(object):
b.reset()
def seed(self, seed=None):
"""Reset all the seed(s) of the given environment(s)."""
if hasattr(self.env, 'seed'):
return self.env.seed(seed)
def render(self, **kwargs):
"""Render all the environment(s)."""
if hasattr(self.env, 'render'):
return self.env.render(**kwargs)
def close(self):
"""Close the environment(s)."""
if hasattr(self.env, 'close'):
self.env.close()
@ -87,12 +139,34 @@ class Collector(object):
return np.array([data])
def collect(self, n_step=0, n_episode=0, render=0):
"""Collect a specified number of step or episode.
:param n_step: an int, indicates how many steps you want to collect.
:param n_episode: an int or a list, indicates how many episodes you
want to collect (in each environment).
:param render: a float, the sleep time between rendering consecutive
frames. ``0`` means no rendering.
.. note::
One and only one collection number specification is permitted,
either ``n_step`` or ``n_episode``.
:return: A dict including the following keys
* ``n/ep`` the collected number of episodes.
* ``n/st`` the collected number of steps.
* ``v/st`` the speed of steps per second.
* ``v/ep`` the speed of episode per second.
* ``rew`` the mean reward over collected episodes.
* ``len`` the mean length over collected episodes.
"""
warning_count = 0
if not self._multi_env:
n_episode = np.sum(n_episode)
start_time = time.time()
assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
"One and only one collection number specification permitted!"
"One and only one collection number specification is permitted!"
cur_step = 0
cur_episode = np.zeros(self.env_num) if self._multi_env else 0
reward_sum = 0
@ -218,6 +292,14 @@ class Collector(object):
}
def sample(self, batch_size):
"""Sample a data batch from the internal replay buffer. It will call
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
:param batch_size: an int, ``0`` means it will extract all the data
from the buffer, otherwise it will extract the given batch_size of
data.
"""
if self._multi_buf:
if batch_size > 0:
lens = [len(b) for b in self.buffer]

View File

@ -12,12 +12,12 @@ from tianshou.env.utils import CloudpickleWrapper
class BaseVectorEnv(ABC, gym.Wrapper):
"""
Base class for vectorized environments wrapper. Usage:
"""Base class for vectorized environments wrapper. Usage:
::
env_num = 8
envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)])
assert len(envs) == env_num
It accepts a list of environment generators. In other words, an environment
generator ``efn`` of a specific task means that ``efn()`` returns the
@ -46,8 +46,7 @@ class BaseVectorEnv(ABC, gym.Wrapper):
@abstractmethod
def reset(self, id=None):
"""
Reset the state of all the environments and returns initial
"""Reset the state of all the environments and return initial
observations if id is ``None``, otherwise reset the specific
environments with given id, either an int or a list.
"""
@ -55,38 +54,38 @@ class BaseVectorEnv(ABC, gym.Wrapper):
@abstractmethod
def step(self, action):
"""
Run one timestep of all the environments dynamics. When end of episode
is reached, you are responsible for calling reset(id) to reset this
environments state.
"""Run one timestep of all the environments dynamics. When the end of
episode is reached, you are responsible for calling reset(id) to reset
this environments state.
Accepts a batch of action and returns a tuple (obs, rew, done, info).
Accept a batch of action and return a tuple (obs, rew, done, info).
:args:
action (numpy.ndarray): a batch of action provided by the agent
:param action: a numpy.ndarray, a batch of action provided by the
agent.
:return:
* obs (numpy.ndarray): agent's observation of current environments
* rew (numpy.ndarray) : amount of rewards returned after previous \
actions
* done (numpy.ndarray): whether these episodes have ended, in \
:return: A tuple including four items:
* ``obs`` a numpy.ndarray, the agent's observation of current \
environments
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
previous actions
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
which case further step() calls will return undefined results
* info (numpy.ndarray): contains auxiliary diagnostic information \
(helpful for debugging, and sometimes learning)
* ``info`` a numpy.ndarray, contains auxiliary diagnostic \
information (helpful for debugging, and sometimes learning)
"""
pass
@abstractmethod
def seed(self, seed=None):
"""
Set the seed for all environments. Accept ``None``, an int (which will
extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
"""Set the seed for all environments. Accept ``None``, an int (which
will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
"""
pass
@abstractmethod
def render(self, **kwargs):
"""Renders the environment."""
"""Render all of the environments."""
pass
@abstractmethod
@ -96,8 +95,7 @@ class BaseVectorEnv(ABC, gym.Wrapper):
class VectorEnv(BaseVectorEnv):
"""
Dummy vectorized environment wrapper, implemented in for-loop. The usage \
"""Dummy vectorized environment wrapper, implemented in for-loop. The usage
is in :class:`~tianshou.env.BaseVectorEnv`.
"""
@ -173,8 +171,7 @@ def worker(parent, p, env_fn_wrapper):
class SubprocVectorEnv(BaseVectorEnv):
"""
Vectorized environment wrapper based on subprocess. The usage is in \
"""Vectorized environment wrapper based on subprocess. The usage is in
:class:`~tianshou.env.BaseVectorEnv`.
"""
@ -248,11 +245,11 @@ class SubprocVectorEnv(BaseVectorEnv):
class RayVectorEnv(BaseVectorEnv):
"""
Vectorized environment wrapper based on \
`ray <https://github.com/ray-project/ray>`_. However, according to our \
test, it is slower than :class:`~tianshou.env.SubprocVectorEnv`. The usage\
is in :class:`~tianshou.env.BaseVectorEnv`.
"""Vectorized environment wrapper based on
`ray <https://github.com/ray-project/ray>`_. However, according to our
test, it is about two times slower than
:class:`~tianshou.env.SubprocVectorEnv`. The usage is in
:class:`~tianshou.env.BaseVectorEnv`.
"""
def __init__(self, env_fns):

View File

@ -2,8 +2,7 @@ import numpy as np
class OUNoise(object):
"""
Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
"""Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
Usage:
::
@ -25,9 +24,8 @@ class OUNoise(object):
self.reset()
def __call__(self, size, mu=.1):
"""
Generate new noise. Return a ``numpy.ndarray`` which size is equal to
``size``.
"""Generate new noise. Return a ``numpy.ndarray`` which size is equal
to ``size``.
"""
if self.x is None or self.x.shape != size:
self.x = 0
@ -36,7 +34,5 @@ class OUNoise(object):
return self.x
def reset(self):
"""
Reset to the initial state.
"""
"""Reset to the initial state."""
self.x = None

View File

@ -64,7 +64,7 @@ class DDPGPolicy(BasePolicy):
def process_fn(self, batch, buffer, indice):
if self._rew_norm:
bfr = buffer.rew[:len(buffer)]
bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
if std > self.__eps:
batch.rew = (batch.rew - mean) / std

View File

@ -8,7 +8,41 @@ from tianshou.trainer import test_episode, gather_info
def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, episode_per_test,
batch_size, train_fn=None, test_fn=None, stop_fn=None,
writer=None, log_interval=1, verbose=True, task=''):
writer=None, log_interval=1, verbose=True, task='',
**kwargs):
"""A wrapper for off-policy trainer procedure.
Parameters
* **policy** an instance of the :class:`~tianshou.policy.BasePolicy`\
class.
* **train_collector** the collector used for training.
* **test_collector** the collector used for testing.
* **max_epoch** the maximum of epochs for training. The training \
process might be finished before reaching the ``max_epoch``.
* **step_per_epoch** the number of step for updating policy network \
in one epoch.
* **collect_per_step** the number of frames the collector would \
collect before the network update. In other words, collect some \
frames and do one policy network update.
* **episode_per_test** the number of episodes for one policy \
evaluation.
* **batch_size** the batch size of sample data, which is going to \
feed in the policy network.
* **train_fn** a function receives the current number of epoch index\
and performs some operations at the beginning of training in this \
epoch.
* **test_fn** a function receives the current number of epoch index \
and performs some operations at the beginning of testing in this \
epoch.
* **stop_fn** a function receives the average undiscounted returns \
of the testing result, return a boolean which indicates whether \
reaching the goal.
* **writer** a SummaryWriter provided from TensorBoard.
* **log_interval** an int indicating the log interval of the writer.
* **verbose** a boolean indicating whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
global_step = 0
best_epoch, best_reward = -1, -1
stat = {}

View File

@ -9,7 +9,44 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, repeat_per_collect,
episode_per_test, batch_size,
train_fn=None, test_fn=None, stop_fn=None,
writer=None, log_interval=1, verbose=True, task=''):
writer=None, log_interval=1, verbose=True, task='',
**kwargs):
"""A wrapper for on-policy trainer procedure.
Parameters
* **policy** an instance of the :class:`~tianshou.policy.BasePolicy`\
class.
* **train_collector** the collector used for training.
* **test_collector** the collector used for testing.
* **max_epoch** the maximum of epochs for training. The training \
process might be finished before reaching the ``max_epoch``.
* **step_per_epoch** the number of step for updating policy network \
in one epoch.
* **collect_per_step** the number of frames the collector would \
collect before the network update. In other words, collect some \
frames and do one policy network update.
* **repeat_per_collect** the number of repeat time for policy \
learning, for example, set it to 2 means the policy needs to learn\
each given batch data twice.
* **episode_per_test** the number of episodes for one policy \
evaluation.
* **batch_size** the batch size of sample data, which is going to \
feed in the policy network.
* **train_fn** a function receives the current number of epoch index\
and performs some operations at the beginning of training in this \
epoch.
* **test_fn** a function receives the current number of epoch index \
and performs some operations at the beginning of testing in this \
epoch.
* **stop_fn** a function receives the average undiscounted returns \
of the testing result, return a boolean which indicates whether \
reaching the goal.
* **writer** a SummaryWriter provided from TensorBoard.
* **log_interval** an int indicating the log interval of the writer.
* **verbose** a boolean indicating whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
global_step = 0
best_epoch, best_reward = -1, -1
stat = {}

View File

@ -18,23 +18,22 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
def gather_info(start_time, train_c, test_c, best_reward):
"""
A simple wrapper of gathering information from collectors.
"""A simple wrapper of gathering information from collectors.
:return: A dictionary with following keys:
:return: A dictionary with the following keys:
* ``train_step``: the total collected step of training collector;
* ``train_episode``: the total collected episode of training collector;
* ``train_time/collector``: the time for collecting frames in the\
* ``train_step`` the total collected step of training collector;
* ``train_episode`` the total collected episode of training collector;
* ``train_time/collector`` the time for collecting frames in the \
training collector;
* ``train_time/model``: the time for training models;
* ``train_speed``: the speed of training (frames per second);
* ``test_step``: the total collected step of test collector;
* ``test_episode``: the total collected episode of test collector;
* ``test_time``: the time for testing;
* ``test_speed``: the speed of testing (frames per second);
* ``best_reward``: the best reward over the test results;
* ``duration``: the total elapsed time.
* ``train_time/model`` the time for training models;
* ``train_speed`` the speed of training (frames per second);
* ``test_step`` the total collected step of test collector;
* ``test_episode`` the total collected episode of test collector;
* ``test_time`` the time for testing;
* ``test_speed`` the speed of testing (frames per second);
* ``best_reward`` the best reward over the test results;
* ``duration`` the total elapsed time.
"""
duration = time.time() - start_time
model_time = duration - train_c.collect_time - test_c.collect_time

View File

@ -3,8 +3,7 @@ import numpy as np
class MovAvg(object):
"""
Class for moving average. Usage:
"""Class for moving average. Usage:
::
>>> stat = MovAvg(size=66)
@ -25,8 +24,7 @@ class MovAvg(object):
self.cache = []
def add(self, x):
"""
Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
only one element, a python scalar, or a list of python scalar. It will
exclude the infinity.
"""