add policy docs (#21)
This commit is contained in:
parent
610390c132
commit
e0809ff135
@ -36,7 +36,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is
|
||||
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
|
||||
|
||||
```bash
|
||||
pip3 install tianshou
|
||||
pip3 install tianshou -U
|
||||
```
|
||||
|
||||
You can also install with the newest version through GitHub:
|
||||
|
||||
@ -8,14 +8,14 @@ Welcome to Tianshou!
|
||||
|
||||
**Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
|
||||
* `Policy Gradient (PG) <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||
* `Deep Q-Network (DQN) <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
|
||||
* `Double DQN (DDQN) <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
|
||||
* `Advantage Actor-Critic (A2C) <https://openai.com/blog/baselines-acktr-a2c/>`_
|
||||
* `Deep Deterministic Policy Gradient (DDPG) <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||
* `Proximal Policy Optimization (PPO) <https://arxiv.org/pdf/1707.06347.pdf>`_
|
||||
* `Twin Delayed DDPG (TD3) <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||
* `Soft Actor-Critic (SAC) <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
|
||||
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
||||
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
||||
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||
|
||||
|
||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
|
||||
@ -27,7 +27,7 @@ Installation
|
||||
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
|
||||
::
|
||||
|
||||
pip3 install tianshou
|
||||
pip3 install tianshou -U
|
||||
|
||||
You can also install with the newest version through GitHub:
|
||||
::
|
||||
|
||||
@ -25,13 +25,14 @@ Data Buffer
|
||||
|
||||
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
|
||||
|
||||
.. _policy_concept:
|
||||
|
||||
Policy
|
||||
------
|
||||
|
||||
Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.
|
||||
|
||||
A policy class typically has four parts:
|
||||
A policy class typically has four parts:
|
||||
|
||||
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
|
||||
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
|
||||
@ -119,7 +120,7 @@ There will be more types of trainers, for instance, multi-agent trainer.
|
||||
A High-level Explanation
|
||||
------------------------
|
||||
|
||||
We give a high-level explanation through the pseudocode used in section Policy:
|
||||
We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
|
||||
::
|
||||
|
||||
# pseudocode, cannot work # methods in tianshou
|
||||
|
||||
@ -68,9 +68,9 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
|
||||
Code-level optimization
|
||||
-----------------------
|
||||
|
||||
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V_s` and :math:`V_{s'}` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
|
||||
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
|
||||
|
||||
Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
|
||||
.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
|
||||
|
||||
|
||||
Finally
|
||||
|
||||
@ -5,7 +5,8 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
|
||||
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
|
||||
softmax=False):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.model = [
|
||||
@ -15,6 +16,8 @@ class Net(nn.Module):
|
||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
||||
if action_shape:
|
||||
self.model += [nn.Linear(128, np.prod(action_shape))]
|
||||
if softmax:
|
||||
self.model += [nn.Softmax(dim=-1)]
|
||||
self.model = nn.Sequential(*self.model)
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
|
||||
@ -113,7 +113,7 @@ def test_pg(args=get_args()):
|
||||
# model
|
||||
net = Net(
|
||||
args.layer_num, args.state_shape, args.action_shape,
|
||||
device=args.device)
|
||||
device=args.device, softmax=True)
|
||||
net = net.to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
|
||||
@ -104,10 +104,11 @@ class Batch(object):
|
||||
def split(self, size=None, permute=True):
|
||||
"""Split whole data into multiple small batch.
|
||||
|
||||
:param size: if it is ``None``, it does not split the data batch;
|
||||
:param int size: if it is ``None``, it does not split the data batch;
|
||||
otherwise it will divide the data batch with the given size.
|
||||
:param permute: randomly shuffle the entire data batch if it is
|
||||
``True``, otherwise remain in the same.
|
||||
Default to ``None``.
|
||||
:param bool permute: randomly shuffle the entire data batch if it is
|
||||
``True``, otherwise remain in the same. Default to ``True``.
|
||||
"""
|
||||
length = len(self)
|
||||
if size is None:
|
||||
|
||||
@ -10,7 +10,22 @@ from tianshou.utils import MovAvg
|
||||
|
||||
class Collector(object):
|
||||
"""The :class:`~tianshou.data.Collector` enables the policy to interact
|
||||
with different types of environments conveniently. Here is the usage:
|
||||
with different types of environments conveniently.
|
||||
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
|
||||
class.
|
||||
:param env: an environment or an instance of the
|
||||
:class:`~tianshou.env.BaseVectorEnv` class.
|
||||
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
|
||||
class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
|
||||
``None``, it will automatically assign a small-size
|
||||
:class:`~tianshou.data.ReplayBuffer`.
|
||||
:param int stat_size: for the moving average of recording speed, defaults
|
||||
to 100.
|
||||
:param bool store_obs_next: whether to store the obs_next to replay
|
||||
buffer, defaults to ``True``.
|
||||
|
||||
Example:
|
||||
::
|
||||
|
||||
policy = PGPolicy(...) # or other policies if you wish
|
||||
@ -55,7 +70,8 @@ class Collector(object):
|
||||
Please make sure the given environment has a time limitation.
|
||||
"""
|
||||
|
||||
def __init__(self, policy, env, buffer=None, stat_size=100):
|
||||
def __init__(self, policy, env, buffer=None, stat_size=100,
|
||||
store_obs_next=True, **kwargs):
|
||||
super().__init__()
|
||||
self.env = env
|
||||
self.env_num = 1
|
||||
@ -90,6 +106,7 @@ class Collector(object):
|
||||
self.state = None
|
||||
self.step_speed = MovAvg(stat_size)
|
||||
self.episode_speed = MovAvg(stat_size)
|
||||
self._save_s_ = store_obs_next
|
||||
|
||||
def reset_buffer(self):
|
||||
"""Reset the main data buffer."""
|
||||
@ -141,11 +158,12 @@ class Collector(object):
|
||||
def collect(self, n_step=0, n_episode=0, render=0):
|
||||
"""Collect a specified number of step or episode.
|
||||
|
||||
:param n_step: an int, indicates how many steps you want to collect.
|
||||
:param n_episode: an int or a list, indicates how many episodes you
|
||||
want to collect (in each environment).
|
||||
:param render: a float, the sleep time between rendering consecutive
|
||||
frames. ``0`` means no rendering.
|
||||
:param int n_step: how many steps you want to collect.
|
||||
:param n_episode: how many episodes you want to collect (in each
|
||||
environment).
|
||||
:type n_episode: int or list
|
||||
:param float render: the sleep time between rendering consecutive
|
||||
frames. No rendering if it is ``0`` (default option).
|
||||
|
||||
.. note::
|
||||
|
||||
@ -210,7 +228,8 @@ class Collector(object):
|
||||
data = {
|
||||
'obs': self._obs[i], 'act': self._act[i],
|
||||
'rew': self._rew[i], 'done': self._done[i],
|
||||
'obs_next': obs_next[i], 'info': self._info[i]}
|
||||
'obs_next': obs_next[i] if self._save_s_ else None,
|
||||
'info': self._info[i]}
|
||||
if self._cached_buf:
|
||||
warning_count += 1
|
||||
self._cached_buf[i].add(**data)
|
||||
@ -255,7 +274,8 @@ class Collector(object):
|
||||
else:
|
||||
self.buffer.add(
|
||||
self._obs, self._act[0], self._rew,
|
||||
self._done, obs_next, self._info)
|
||||
self._done, obs_next if self._save_s_ else None,
|
||||
self._info)
|
||||
cur_step += 1
|
||||
if self._done:
|
||||
cur_episode += 1
|
||||
@ -296,9 +316,9 @@ class Collector(object):
|
||||
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
|
||||
the final batch data.
|
||||
|
||||
:param batch_size: an int, ``0`` means it will extract all the data
|
||||
from the buffer, otherwise it will extract the given batch_size of
|
||||
data.
|
||||
:param int batch_size: ``0`` means it will extract all the data from
|
||||
the buffer, otherwise it will extract the data with the given
|
||||
batch_size.
|
||||
"""
|
||||
if self._multi_buf:
|
||||
if batch_size > 0:
|
||||
|
||||
3
tianshou/env/vecenv.py
vendored
3
tianshou/env/vecenv.py
vendored
@ -60,8 +60,7 @@ class BaseVectorEnv(ABC, gym.Wrapper):
|
||||
|
||||
Accept a batch of action and return a tuple (obs, rew, done, info).
|
||||
|
||||
:param action: a numpy.ndarray, a batch of action provided by the
|
||||
agent.
|
||||
:param numpy.ndarray action: a batch of action provided by the agent.
|
||||
|
||||
:return: A tuple including four items:
|
||||
|
||||
|
||||
@ -7,12 +7,26 @@ from tianshou.policy import PGPolicy
|
||||
|
||||
|
||||
class A2CPolicy(PGPolicy):
|
||||
"""docstring for A2CPolicy"""
|
||||
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.nn.Module critic: the critic network. (s -> V(s))
|
||||
:param torch.optim.Optimizer optim: the optimizer for actor and critic
|
||||
network.
|
||||
:param torch.distributions.Distribution dist_fn: for computing the action,
|
||||
defaults to ``torch.distributions.Categorical``.
|
||||
:param float discount_factor: in [0, 1], defaults to 0.99.
|
||||
:param float vf_coef: weight for value loss, defaults to 0.5.
|
||||
:param float ent_coef: weight for entropy loss, defaults to 0.01.
|
||||
:param float max_grad_norm: clipping gradients in back propagation,
|
||||
defaults to ``None``.
|
||||
"""
|
||||
|
||||
def __init__(self, actor, critic, optim,
|
||||
dist_fn=torch.distributions.Categorical,
|
||||
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
|
||||
max_grad_norm=None):
|
||||
max_grad_norm=None, **kwargs):
|
||||
super().__init__(None, optim, dist_fn, discount_factor)
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
@ -20,13 +34,28 @@ class A2CPolicy(PGPolicy):
|
||||
self._w_ent = ent_coef
|
||||
self._grad_norm = max_grad_norm
|
||||
|
||||
def __call__(self, batch, state=None):
|
||||
def __call__(self, batch, state=None, **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
|
||||
|
||||
* ``act`` the action.
|
||||
* ``logits`` the network's raw output.
|
||||
* ``dist`` the action distribution.
|
||||
* ``state`` the hidden state.
|
||||
|
||||
More information can be found at
|
||||
:meth:`~tianshou.policy.BasePolicy.__call__`.
|
||||
"""
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = self.dist_fn(logits)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1):
|
||||
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
|
||||
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
|
||||
@ -3,23 +3,74 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BasePolicy(ABC, nn.Module):
|
||||
"""docstring for BasePolicy"""
|
||||
"""Tianshou aims to modularizing RL algorithms. It comes into several
|
||||
classes of policies in Tianshou. All of the policy classes must inherit
|
||||
:class:`~tianshou.policy.BasePolicy`.
|
||||
|
||||
def __init__(self):
|
||||
A policy class typically has four parts:
|
||||
|
||||
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
|
||||
including coping the target network and so on;
|
||||
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \
|
||||
observation;
|
||||
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
|
||||
the replay buffer (this function can interact with replay buffer);
|
||||
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
|
||||
batch of data.
|
||||
|
||||
Most of the policy needs a neural network to predict the action and an
|
||||
optimizer to optimize the policy. The rules of self-defined networks are:
|
||||
|
||||
1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or \
|
||||
``torch.Tensor``), hidden state ``state`` (for RNN usage), and other \
|
||||
information ``info`` provided by the environment.
|
||||
2. Output: some ``logits`` and the next hidden state ``state``. The logits\
|
||||
could be a tuple instead of a ``torch.Tensor``. It depends on how the \
|
||||
policy process the network output. For example, in PPO, the return of \
|
||||
the network might be ``(mu, sigma), state`` for Gaussian policy.
|
||||
|
||||
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
|
||||
you can operate :class:`~tianshou.policy.BasePolicy` almost the same as
|
||||
``torch.nn.Module``, for instance, load and save the model:
|
||||
::
|
||||
|
||||
torch.save(policy.state_dict(), 'policy.pth')
|
||||
policy.load_state_dict(torch.load('policy.pth'))
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
"""Pre-process the data from the provided replay buffer. Check out
|
||||
:ref:`policy_concept` for more information.
|
||||
"""
|
||||
return batch
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, batch, state=None):
|
||||
# return Batch(logits=..., act=..., state=None, ...)
|
||||
def __call__(self, batch, state=None, **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which MUST have the following\
|
||||
keys:
|
||||
|
||||
* ``act`` an numpy.ndarray or a torch.Tensor, the action over \
|
||||
given batch data.
|
||||
* ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
|
||||
internal state of the policy, ``None`` as default.
|
||||
|
||||
Other keys are user-defined. It depends on the algorithm. For example,
|
||||
::
|
||||
|
||||
# some code
|
||||
return Batch(logits=..., act=..., state=None, dist=...)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, batch, batch_size=None):
|
||||
# return a dict which includes loss and its name
|
||||
pass
|
||||
def learn(self, batch, **kwargs):
|
||||
"""Update policy with a given batch of data.
|
||||
|
||||
def sync_weight(self):
|
||||
:return: A dict which includes loss and its corresponding label.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -5,18 +5,35 @@ import torch.nn.functional as F
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
# from tianshou.exploration import OUNoise
|
||||
|
||||
|
||||
class DDPGPolicy(BasePolicy):
|
||||
"""docstring for DDPGPolicy"""
|
||||
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||
:param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
|
||||
:param torch.optim.Optimizer critic_optim: the optimizer for critic
|
||||
network.
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
:param float exploration_noise: the noise intensity, add to the action,
|
||||
defaults to 0.1.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: [float, float]
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to ``False``.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
defaults to ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, actor, actor_optim, critic, critic_optim,
|
||||
tau=0.005, gamma=0.99, exploration_noise=0.1,
|
||||
action_range=None, reward_normalization=False,
|
||||
ignore_done=False):
|
||||
ignore_done=False, **kwargs):
|
||||
super().__init__()
|
||||
if actor is not None:
|
||||
self.actor, self.actor_old = actor, deepcopy(actor)
|
||||
@ -26,9 +43,9 @@ class DDPGPolicy(BasePolicy):
|
||||
self.critic, self.critic_old = critic, deepcopy(critic)
|
||||
self.critic_old.eval()
|
||||
self.critic_optim = critic_optim
|
||||
assert 0 < tau <= 1, 'tau should in (0, 1]'
|
||||
assert 0 <= tau <= 1, 'tau should in [0, 1]'
|
||||
self._tau = tau
|
||||
assert 0 < gamma <= 1, 'gamma should in (0, 1]'
|
||||
assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
|
||||
self._gamma = gamma
|
||||
assert 0 <= exploration_noise, 'noise should not be negative'
|
||||
self._eps = exploration_noise
|
||||
@ -43,19 +60,23 @@ class DDPGPolicy(BasePolicy):
|
||||
self.__eps = np.finfo(np.float32).eps.item()
|
||||
|
||||
def set_eps(self, eps):
|
||||
"""Set the eps for exploration."""
|
||||
self._eps = eps
|
||||
|
||||
def train(self):
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
self.training = True
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
|
||||
def eval(self):
|
||||
"""Set the module in evaluation mode, except for the target network."""
|
||||
self.training = False
|
||||
self.actor.eval()
|
||||
self.critic.eval()
|
||||
|
||||
def sync_weight(self):
|
||||
"""Soft-update the weight for the target network."""
|
||||
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
for o, n in zip(
|
||||
@ -73,7 +94,19 @@ class DDPGPolicy(BasePolicy):
|
||||
return batch
|
||||
|
||||
def __call__(self, batch, state=None,
|
||||
model='actor', input='obs', eps=None):
|
||||
model='actor', input='obs', eps=None, **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:param float eps: in [0, 1], for exploration use.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
|
||||
|
||||
* ``act`` the action.
|
||||
* ``state`` the hidden state.
|
||||
|
||||
More information can be found at
|
||||
:meth:`~tianshou.policy.BasePolicy.__call__`.
|
||||
"""
|
||||
model = getattr(self, model)
|
||||
obs = getattr(batch, input)
|
||||
logits, h = model(obs, state=state, info=batch.info)
|
||||
@ -89,7 +122,7 @@ class DDPGPolicy(BasePolicy):
|
||||
logits = logits.clamp(self._range[0], self._range[1])
|
||||
return Batch(act=logits, state=h)
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1):
|
||||
def learn(self, batch, **kwargs):
|
||||
with torch.no_grad():
|
||||
target_q = self.critic_old(batch.obs_next, self(
|
||||
batch, model='actor_old', input='obs_next', eps=0).act)
|
||||
|
||||
@ -8,11 +8,20 @@ from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
class DQNPolicy(BasePolicy):
|
||||
"""docstring for DQNPolicy"""
|
||||
"""Implementation of Deep Q Network. arXiv:1312.5602
|
||||
|
||||
:param torch.nn.Module model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||
:param float discount_factor: in [0, 1].
|
||||
:param int estimation_step: greater than 1, the number of steps to look
|
||||
ahead.
|
||||
:param int target_update_freq: the target network update frequency (``0``
|
||||
if you do not use the target network).
|
||||
"""
|
||||
|
||||
def __init__(self, model, optim, discount_factor=0.99,
|
||||
estimation_step=1, use_target_network=True,
|
||||
target_update_freq=300):
|
||||
estimation_step=1, target_update_freq=0, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.optim = optim
|
||||
@ -21,28 +30,44 @@ class DQNPolicy(BasePolicy):
|
||||
self._gamma = discount_factor
|
||||
assert estimation_step > 0, 'estimation_step should greater than 0'
|
||||
self._n_step = estimation_step
|
||||
self._target = use_target_network
|
||||
self._target = target_update_freq > 0
|
||||
self._freq = target_update_freq
|
||||
self._cnt = 0
|
||||
if use_target_network:
|
||||
if self._target:
|
||||
self.model_old = deepcopy(self.model)
|
||||
self.model_old.eval()
|
||||
|
||||
def set_eps(self, eps):
|
||||
"""Set the eps for epsilon-greedy exploration."""
|
||||
self.eps = eps
|
||||
|
||||
def train(self):
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
self.training = True
|
||||
self.model.train()
|
||||
|
||||
def eval(self):
|
||||
"""Set the module in evaluation mode, except for the target network."""
|
||||
self.training = False
|
||||
self.model.eval()
|
||||
|
||||
def sync_weight(self):
|
||||
"""Synchronize the weight for the target network."""
|
||||
self.model_old.load_state_dict(self.model.state_dict())
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
r"""Compute the n-step return for Q-learning targets:
|
||||
|
||||
.. math::
|
||||
G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
|
||||
\gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a
|
||||
(Q_{new}(s_{t + n}, a)))
|
||||
|
||||
, where :math:`\gamma` is the discount factor,
|
||||
:math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
|
||||
:math:`t`. If there is no target network, the :math:`Q_{old}` is equal
|
||||
to :math:`Q_{new}`.
|
||||
"""
|
||||
returns = np.zeros_like(indice)
|
||||
gammas = np.zeros_like(indice) + self._n_step
|
||||
for n in range(self._n_step - 1, -1, -1):
|
||||
@ -70,7 +95,20 @@ class DQNPolicy(BasePolicy):
|
||||
return batch
|
||||
|
||||
def __call__(self, batch, state=None,
|
||||
model='model', input='obs', eps=None):
|
||||
model='model', input='obs', eps=None, **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:param float eps: in [0, 1], for epsilon-greedy exploration method.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 3 keys:
|
||||
|
||||
* ``act`` the action.
|
||||
* ``logits`` the network's raw output.
|
||||
* ``state`` the hidden state.
|
||||
|
||||
More information can be found at
|
||||
:meth:`~tianshou.policy.BasePolicy.__call__`.
|
||||
"""
|
||||
model = getattr(self, model)
|
||||
obs = getattr(batch, input)
|
||||
q, h = model(obs, state=state, info=batch.info)
|
||||
@ -83,7 +121,7 @@ class DQNPolicy(BasePolicy):
|
||||
act[i] = np.random.randint(q.shape[1])
|
||||
return Batch(logits=q, act=act, state=h)
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1):
|
||||
def learn(self, batch, **kwargs):
|
||||
if self._target and self._cnt % self._freq == 0:
|
||||
self.sync_weight()
|
||||
self.optim.zero_grad()
|
||||
|
||||
@ -1,37 +1,65 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
class PGPolicy(BasePolicy):
|
||||
"""docstring for PGPolicy"""
|
||||
"""Implementation of Vanilla Policy Gradient.
|
||||
|
||||
:param torch.nn.Module model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||
:param torch.distributions.Distribution dist_fn: for computing the action.
|
||||
:param float discount_factor: in [0, 1].
|
||||
"""
|
||||
|
||||
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
|
||||
discount_factor=0.99):
|
||||
discount_factor=0.99, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.optim = optim
|
||||
self.dist_fn = dist_fn
|
||||
self._eps = np.finfo(np.float32).eps.item()
|
||||
assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
|
||||
assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
|
||||
self._gamma = discount_factor
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
r"""Compute the discounted returns for each frame:
|
||||
|
||||
.. math::
|
||||
G_t = \sum_{i=t}^T \gamma^{i-t}r_i
|
||||
|
||||
, where :math:`T` is the terminal time step, :math:`\gamma` is the
|
||||
discount factor, :math:`\gamma \in [0, 1]`.
|
||||
"""
|
||||
batch.returns = self._vanilla_returns(batch)
|
||||
# batch.returns = self._vectorized_returns(batch)
|
||||
return batch
|
||||
|
||||
def __call__(self, batch, state=None):
|
||||
def __call__(self, batch, state=None, **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
|
||||
|
||||
* ``act`` the action.
|
||||
* ``logits`` the network's raw output.
|
||||
* ``dist`` the action distribution.
|
||||
* ``state`` the hidden state.
|
||||
|
||||
More information can be found at
|
||||
:meth:`~tianshou.policy.BasePolicy.__call__`.
|
||||
"""
|
||||
logits, h = self.model(batch.obs, state=state, info=batch.info)
|
||||
logits = F.softmax(logits, dim=1)
|
||||
dist = self.dist_fn(logits)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1):
|
||||
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
|
||||
losses = []
|
||||
r = batch.returns
|
||||
batch.returns = (r - r.mean()) / (r.std() + self._eps)
|
||||
@ -57,7 +85,7 @@ class PGPolicy(BasePolicy):
|
||||
return returns
|
||||
|
||||
def _vectorized_returns(self, batch):
|
||||
# according to my tests, it is slower than vanilla
|
||||
# according to my tests, it is slower than _vanilla_returns
|
||||
# import scipy.signal
|
||||
convolve = np.convolve
|
||||
# convolve = scipy.signal.convolve
|
||||
|
||||
@ -9,7 +9,24 @@ from tianshou.policy import PGPolicy
|
||||
|
||||
|
||||
class PPOPolicy(PGPolicy):
|
||||
"""docstring for PPOPolicy"""
|
||||
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.nn.Module critic: the critic network. (s -> V(s))
|
||||
:param torch.optim.Optimizer optim: the optimizer for actor and critic
|
||||
network.
|
||||
:param torch.distributions.Distribution dist_fn: for computing the action.
|
||||
:param float discount_factor: in [0, 1], defaults to 0.99.
|
||||
:param float max_grad_norm: clipping gradients in back propagation,
|
||||
defaults to ``None``.
|
||||
:param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
|
||||
paper, defaults to 0.2.
|
||||
:param float vf_coef: weight for value loss, defaults to 0.5.
|
||||
:param float ent_coef: weight for entropy loss, defaults to 0.01.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: [float, float]
|
||||
"""
|
||||
|
||||
def __init__(self, actor, critic, optim, dist_fn,
|
||||
discount_factor=0.99,
|
||||
@ -17,7 +34,8 @@ class PPOPolicy(PGPolicy):
|
||||
eps_clip=.2,
|
||||
vf_coef=.5,
|
||||
ent_coef=.0,
|
||||
action_range=None):
|
||||
action_range=None,
|
||||
**kwargs):
|
||||
super().__init__(None, None, dist_fn, discount_factor)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
@ -31,16 +49,30 @@ class PPOPolicy(PGPolicy):
|
||||
self.optim = optim
|
||||
|
||||
def train(self):
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
self.training = True
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
|
||||
def eval(self):
|
||||
"""Set the module in evaluation mode, except for the target network."""
|
||||
self.training = False
|
||||
self.actor.eval()
|
||||
self.critic.eval()
|
||||
|
||||
def __call__(self, batch, state=None, model='actor'):
|
||||
def __call__(self, batch, state=None, model='actor', **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
|
||||
|
||||
* ``act`` the action.
|
||||
* ``logits`` the network's raw output.
|
||||
* ``dist`` the action distribution.
|
||||
* ``state`` the hidden state.
|
||||
|
||||
More information can be found at
|
||||
:meth:`~tianshou.policy.BasePolicy.__call__`.
|
||||
"""
|
||||
model = getattr(self, model)
|
||||
logits, h = model(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
@ -53,10 +85,11 @@ class PPOPolicy(PGPolicy):
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def sync_weight(self):
|
||||
"""Synchronize the weight for the target network."""
|
||||
self.actor_old.load_state_dict(self.actor.state_dict())
|
||||
self.critic_old.load_state_dict(self.critic.state_dict())
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1):
|
||||
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
|
||||
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
||||
r = batch.returns
|
||||
batch.returns = (r - r.mean()) / (r.std() + self._eps)
|
||||
@ -79,7 +112,6 @@ class PPOPolicy(PGPolicy):
|
||||
clip_losses.append(clip_loss.item())
|
||||
vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
|
||||
vf_losses.append(vf_loss.item())
|
||||
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
|
||||
|
||||
@ -8,12 +8,37 @@ from tianshou.policy import DDPGPolicy
|
||||
|
||||
|
||||
class SACPolicy(DDPGPolicy):
|
||||
"""docstring for SACPolicy"""
|
||||
"""Implementation of Soft Actor-Critic. arXiv:1812.05905
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s,
|
||||
a))
|
||||
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
|
||||
critic network.
|
||||
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s,
|
||||
a))
|
||||
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||
critic network.
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
:param float exploration_noise: the noise intensity, add to the action,
|
||||
defaults to 0.1.
|
||||
:param float alpha: entropy regularization coefficient, default to 0.2.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: [float, float]
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to ``False``.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
defaults to ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, actor, actor_optim, critic1, critic1_optim,
|
||||
critic2, critic2_optim, tau=0.005, gamma=0.99,
|
||||
alpha=0.2, action_range=None, reward_normalization=False,
|
||||
ignore_done=False):
|
||||
ignore_done=False, **kwargs):
|
||||
super().__init__(None, None, None, None, tau, gamma, 0,
|
||||
action_range, reward_normalization, ignore_done)
|
||||
self.actor, self.actor_optim = actor, actor_optim
|
||||
@ -46,12 +71,11 @@ class SACPolicy(DDPGPolicy):
|
||||
self.critic2_old.parameters(), self.critic2.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
|
||||
def __call__(self, batch, state=None, input='obs'):
|
||||
def __call__(self, batch, state=None, input='obs', **kwargs):
|
||||
obs = getattr(batch, input)
|
||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||
assert isinstance(logits, tuple)
|
||||
dist = torch.distributions.Normal(*logits)
|
||||
|
||||
x = dist.rsample()
|
||||
y = torch.tanh(x)
|
||||
act = y * self._action_scale + self._action_bias
|
||||
@ -61,7 +85,7 @@ class SACPolicy(DDPGPolicy):
|
||||
return Batch(
|
||||
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1):
|
||||
def learn(self, batch, **kwargs):
|
||||
with torch.no_grad():
|
||||
obs_next_result = self(batch, input='obs_next')
|
||||
a_ = obs_next_result.act
|
||||
|
||||
@ -6,13 +6,44 @@ from tianshou.policy import DDPGPolicy
|
||||
|
||||
|
||||
class TD3Policy(DDPGPolicy):
|
||||
"""docstring for TD3Policy"""
|
||||
"""Implementation of Twin Delayed Deep Deterministic Policy Gradient,
|
||||
arXiv:1802.09477
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s,
|
||||
a))
|
||||
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
|
||||
critic network.
|
||||
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s,
|
||||
a))
|
||||
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||
critic network.
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
:param float exploration_noise: the noise intensity, add to the action,
|
||||
defaults to 0.1.
|
||||
:param float policy_noise: the noise used in updating policy network,
|
||||
default to 0.2.
|
||||
:param int update_actor_freq: the update frequency of actor network,
|
||||
default to 2.
|
||||
:param float noise_clip: the clipping range used in updating policy
|
||||
network, default to 0.5.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: [float, float]
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to ``False``.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
defaults to ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, actor, actor_optim, critic1, critic1_optim,
|
||||
critic2, critic2_optim, tau=0.005, gamma=0.99,
|
||||
exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
|
||||
noise_clip=0.5, action_range=None,
|
||||
reward_normalization=False, ignore_done=False):
|
||||
reward_normalization=False, ignore_done=False, **kwargs):
|
||||
super().__init__(actor, actor_optim, None, None, tau, gamma,
|
||||
exploration_noise, action_range, reward_normalization,
|
||||
ignore_done)
|
||||
@ -50,7 +81,7 @@ class TD3Policy(DDPGPolicy):
|
||||
self.critic2_old.parameters(), self.critic2.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1):
|
||||
def learn(self, batch, **kwargs):
|
||||
with torch.no_grad():
|
||||
a_ = self(batch, model='actor_old', input='obs_next').act
|
||||
dev = a_.device
|
||||
|
||||
@ -12,34 +12,35 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
**kwargs):
|
||||
"""A wrapper for off-policy trainer procedure.
|
||||
|
||||
Parameters
|
||||
* **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\
|
||||
class.
|
||||
* **train_collector** – the collector used for training.
|
||||
* **test_collector** – the collector used for testing.
|
||||
* **max_epoch** – the maximum of epochs for training. The training \
|
||||
process might be finished before reaching the ``max_epoch``.
|
||||
* **step_per_epoch** – the number of step for updating policy network \
|
||||
in one epoch.
|
||||
* **collect_per_step** – the number of frames the collector would \
|
||||
collect before the network update. In other words, collect some \
|
||||
frames and do one policy network update.
|
||||
* **episode_per_test** – the number of episodes for one policy \
|
||||
evaluation.
|
||||
* **batch_size** – the batch size of sample data, which is going to \
|
||||
feed in the policy network.
|
||||
* **train_fn** – a function receives the current number of epoch index\
|
||||
and performs some operations at the beginning of training in this \
|
||||
epoch.
|
||||
* **test_fn** – a function receives the current number of epoch index \
|
||||
and performs some operations at the beginning of testing in this \
|
||||
epoch.
|
||||
* **stop_fn** – a function receives the average undiscounted returns \
|
||||
of the testing result, return a boolean which indicates whether \
|
||||
reaching the goal.
|
||||
* **writer** – a SummaryWriter provided from TensorBoard.
|
||||
* **log_interval** – an int indicating the log interval of the writer.
|
||||
* **verbose** – a boolean indicating whether to print the information.
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
|
||||
class.
|
||||
:param train_collector: the collector used for training.
|
||||
:type train_collector: :class:`~tianshou.data.Collector`
|
||||
:param test_collector: the collector used for testing.
|
||||
:type test_collector: :class:`~tianshou.data.Collector`
|
||||
:param int max_epoch: the maximum of epochs for training. The training
|
||||
process might be finished before reaching the ``max_epoch``.
|
||||
:param int step_per_epoch: the number of step for updating policy network
|
||||
in one epoch.
|
||||
:param int collect_per_step: the number of frames the collector would
|
||||
collect before the network update. In other words, collect some frames
|
||||
and do one policy network update.
|
||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||
:param int batch_size: the batch size of sample data, which is going to
|
||||
feed in the policy network.
|
||||
:param function train_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of training in this
|
||||
epoch.
|
||||
:param function test_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of testing in this
|
||||
epoch.
|
||||
:param function stop_fn: a function receives the average undiscounted
|
||||
returns of the testing result, return a boolean which indicates whether
|
||||
reaching the goal.
|
||||
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
|
||||
SummaryWriter.
|
||||
:param int log_interval: the log interval of the writer.
|
||||
:param bool verbose: whether to print the information.
|
||||
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
|
||||
@ -13,37 +13,39 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
**kwargs):
|
||||
"""A wrapper for on-policy trainer procedure.
|
||||
|
||||
Parameters
|
||||
* **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\
|
||||
class.
|
||||
* **train_collector** – the collector used for training.
|
||||
* **test_collector** – the collector used for testing.
|
||||
* **max_epoch** – the maximum of epochs for training. The training \
|
||||
process might be finished before reaching the ``max_epoch``.
|
||||
* **step_per_epoch** – the number of step for updating policy network \
|
||||
in one epoch.
|
||||
* **collect_per_step** – the number of frames the collector would \
|
||||
collect before the network update. In other words, collect some \
|
||||
frames and do one policy network update.
|
||||
* **repeat_per_collect** – the number of repeat time for policy \
|
||||
learning, for example, set it to 2 means the policy needs to learn\
|
||||
each given batch data twice.
|
||||
* **episode_per_test** – the number of episodes for one policy \
|
||||
evaluation.
|
||||
* **batch_size** – the batch size of sample data, which is going to \
|
||||
feed in the policy network.
|
||||
* **train_fn** – a function receives the current number of epoch index\
|
||||
and performs some operations at the beginning of training in this \
|
||||
epoch.
|
||||
* **test_fn** – a function receives the current number of epoch index \
|
||||
and performs some operations at the beginning of testing in this \
|
||||
epoch.
|
||||
* **stop_fn** – a function receives the average undiscounted returns \
|
||||
of the testing result, return a boolean which indicates whether \
|
||||
reaching the goal.
|
||||
* **writer** – a SummaryWriter provided from TensorBoard.
|
||||
* **log_interval** – an int indicating the log interval of the writer.
|
||||
* **verbose** – a boolean indicating whether to print the information.
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
|
||||
class.
|
||||
:param train_collector: the collector used for training.
|
||||
:type train_collector: :class:`~tianshou.data.Collector`
|
||||
:param test_collector: the collector used for testing.
|
||||
:type test_collector: :class:`~tianshou.data.Collector`
|
||||
:param int max_epoch: the maximum of epochs for training. The training
|
||||
process might be finished before reaching the ``max_epoch``.
|
||||
:param int step_per_epoch: the number of step for updating policy network
|
||||
in one epoch.
|
||||
:param int collect_per_step: the number of frames the collector would
|
||||
collect before the network update. In other words, collect some frames
|
||||
and do one policy network update.
|
||||
:param int repeat_per_collect: the number of repeat time for policy
|
||||
learning, for example, set it to 2 means the policy needs to learn each
|
||||
given batch data twice.
|
||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||
:type episode_per_test: int or list of ints
|
||||
:param int batch_size: the batch size of sample data, which is going to
|
||||
feed in the policy network.
|
||||
:param function train_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of training in this
|
||||
epoch.
|
||||
:param function test_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of testing in this
|
||||
epoch.
|
||||
:param function stop_fn: a function receives the average undiscounted
|
||||
returns of the testing result, return a boolean which indicates whether
|
||||
reaching the goal.
|
||||
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
|
||||
SummaryWriter.
|
||||
:param int log_interval: the log interval of the writer.
|
||||
:param bool verbose: whether to print the information.
|
||||
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
|
||||
@ -26,7 +26,7 @@ class MovAvg(object):
|
||||
def add(self, x):
|
||||
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
|
||||
only one element, a python scalar, or a list of python scalar. It will
|
||||
exclude the infinity.
|
||||
automatically exclude the infinity.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.item()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user