add policy docs (#21)

This commit is contained in:
Trinkle23897 2020-04-06 19:36:59 +08:00
parent 610390c132
commit e0809ff135
20 changed files with 436 additions and 143 deletions

View File

@ -36,7 +36,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
```bash
pip3 install tianshou
pip3 install tianshou -U
```
You can also install with the newest version through GitHub:

View File

@ -8,14 +8,14 @@ Welcome to Tianshou!
**Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
* `Policy Gradient (PG) <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* `Deep Q-Network (DQN) <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* `Double DQN (DDQN) <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
* `Advantage Actor-Critic (A2C) <https://openai.com/blog/baselines-acktr-a2c/>`_
* `Deep Deterministic Policy Gradient (DDPG) <https://arxiv.org/pdf/1509.02971.pdf>`_
* `Proximal Policy Optimization (PPO) <https://arxiv.org/pdf/1707.06347.pdf>`_
* `Twin Delayed DDPG (TD3) <https://arxiv.org/pdf/1802.09477.pdf>`_
* `Soft Actor-Critic (SAC) <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
@ -27,7 +27,7 @@ Installation
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
::
pip3 install tianshou
pip3 install tianshou -U
You can also install with the newest version through GitHub:
::

View File

@ -25,13 +25,14 @@ Data Buffer
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
.. _policy_concept:
Policy
------
Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.
A policy class typically has four parts:
A policy class typically has four parts:
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
@ -119,7 +120,7 @@ There will be more types of trainers, for instance, multi-agent trainer.
A High-level Explanation
------------------------
We give a high-level explanation through the pseudocode used in section Policy:
We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
::
# pseudocode, cannot work # methods in tianshou

View File

@ -68,9 +68,9 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
Code-level optimization
-----------------------
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V_s` and :math:`V_{s'}` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
Finally

View File

@ -5,7 +5,8 @@ import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
softmax=False):
super().__init__()
self.device = device
self.model = [
@ -15,6 +16,8 @@ class Net(nn.Module):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
if action_shape:
self.model += [nn.Linear(128, np.prod(action_shape))]
if softmax:
self.model += [nn.Softmax(dim=-1)]
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):

View File

@ -113,7 +113,7 @@ def test_pg(args=get_args()):
# model
net = Net(
args.layer_num, args.state_shape, args.action_shape,
device=args.device)
device=args.device, softmax=True)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
dist = torch.distributions.Categorical

View File

@ -104,10 +104,11 @@ class Batch(object):
def split(self, size=None, permute=True):
"""Split whole data into multiple small batch.
:param size: if it is ``None``, it does not split the data batch;
:param int size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size.
:param permute: randomly shuffle the entire data batch if it is
``True``, otherwise remain in the same.
Default to ``None``.
:param bool permute: randomly shuffle the entire data batch if it is
``True``, otherwise remain in the same. Default to ``True``.
"""
length = len(self)
if size is None:

View File

@ -10,7 +10,22 @@ from tianshou.utils import MovAvg
class Collector(object):
"""The :class:`~tianshou.data.Collector` enables the policy to interact
with different types of environments conveniently. Here is the usage:
with different types of environments conveniently.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
:param env: an environment or an instance of the
:class:`~tianshou.env.BaseVectorEnv` class.
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
``None``, it will automatically assign a small-size
:class:`~tianshou.data.ReplayBuffer`.
:param int stat_size: for the moving average of recording speed, defaults
to 100.
:param bool store_obs_next: whether to store the obs_next to replay
buffer, defaults to ``True``.
Example:
::
policy = PGPolicy(...) # or other policies if you wish
@ -55,7 +70,8 @@ class Collector(object):
Please make sure the given environment has a time limitation.
"""
def __init__(self, policy, env, buffer=None, stat_size=100):
def __init__(self, policy, env, buffer=None, stat_size=100,
store_obs_next=True, **kwargs):
super().__init__()
self.env = env
self.env_num = 1
@ -90,6 +106,7 @@ class Collector(object):
self.state = None
self.step_speed = MovAvg(stat_size)
self.episode_speed = MovAvg(stat_size)
self._save_s_ = store_obs_next
def reset_buffer(self):
"""Reset the main data buffer."""
@ -141,11 +158,12 @@ class Collector(object):
def collect(self, n_step=0, n_episode=0, render=0):
"""Collect a specified number of step or episode.
:param n_step: an int, indicates how many steps you want to collect.
:param n_episode: an int or a list, indicates how many episodes you
want to collect (in each environment).
:param render: a float, the sleep time between rendering consecutive
frames. ``0`` means no rendering.
:param int n_step: how many steps you want to collect.
:param n_episode: how many episodes you want to collect (in each
environment).
:type n_episode: int or list
:param float render: the sleep time between rendering consecutive
frames. No rendering if it is ``0`` (default option).
.. note::
@ -210,7 +228,8 @@ class Collector(object):
data = {
'obs': self._obs[i], 'act': self._act[i],
'rew': self._rew[i], 'done': self._done[i],
'obs_next': obs_next[i], 'info': self._info[i]}
'obs_next': obs_next[i] if self._save_s_ else None,
'info': self._info[i]}
if self._cached_buf:
warning_count += 1
self._cached_buf[i].add(**data)
@ -255,7 +274,8 @@ class Collector(object):
else:
self.buffer.add(
self._obs, self._act[0], self._rew,
self._done, obs_next, self._info)
self._done, obs_next if self._save_s_ else None,
self._info)
cur_step += 1
if self._done:
cur_episode += 1
@ -296,9 +316,9 @@ class Collector(object):
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
:param batch_size: an int, ``0`` means it will extract all the data
from the buffer, otherwise it will extract the given batch_size of
data.
:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
batch_size.
"""
if self._multi_buf:
if batch_size > 0:

View File

@ -60,8 +60,7 @@ class BaseVectorEnv(ABC, gym.Wrapper):
Accept a batch of action and return a tuple (obs, rew, done, info).
:param action: a numpy.ndarray, a batch of action provided by the
agent.
:param numpy.ndarray action: a batch of action provided by the agent.
:return: A tuple including four items:

View File

@ -7,12 +7,26 @@ from tianshou.policy import PGPolicy
class A2CPolicy(PGPolicy):
"""docstring for A2CPolicy"""
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
:param torch.distributions.Distribution dist_fn: for computing the action,
defaults to ``torch.distributions.Categorical``.
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to ``None``.
"""
def __init__(self, actor, critic, optim,
dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
max_grad_norm=None):
max_grad_norm=None, **kwargs):
super().__init__(None, optim, dist_fn, discount_factor)
self.actor = actor
self.critic = critic
@ -20,13 +34,28 @@ class A2CPolicy(PGPolicy):
self._w_ent = ent_coef
self._grad_norm = max_grad_norm
def __call__(self, batch, state=None):
def __call__(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
dist = self.dist_fn(logits)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size):

View File

@ -3,23 +3,74 @@ from abc import ABC, abstractmethod
class BasePolicy(ABC, nn.Module):
"""docstring for BasePolicy"""
"""Tianshou aims to modularizing RL algorithms. It comes into several
classes of policies in Tianshou. All of the policy classes must inherit
:class:`~tianshou.policy.BasePolicy`.
def __init__(self):
A policy class typically has four parts:
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
including coping the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \
observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
the replay buffer (this function can interact with replay buffer);
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
batch of data.
Most of the policy needs a neural network to predict the action and an
optimizer to optimize the policy. The rules of self-defined networks are:
1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or \
``torch.Tensor``), hidden state ``state`` (for RNN usage), and other \
information ``info`` provided by the environment.
2. Output: some ``logits`` and the next hidden state ``state``. The logits\
could be a tuple instead of a ``torch.Tensor``. It depends on how the \
policy process the network output. For example, in PPO, the return of \
the network might be ``(mu, sigma), state`` for Gaussian policy.
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
you can operate :class:`~tianshou.policy.BasePolicy` almost the same as
``torch.nn.Module``, for instance, load and save the model:
::
torch.save(policy.state_dict(), 'policy.pth')
policy.load_state_dict(torch.load('policy.pth'))
"""
def __init__(self, **kwargs):
super().__init__()
def process_fn(self, batch, buffer, indice):
"""Pre-process the data from the provided replay buffer. Check out
:ref:`policy_concept` for more information.
"""
return batch
@abstractmethod
def __call__(self, batch, state=None):
# return Batch(logits=..., act=..., state=None, ...)
def __call__(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following\
keys:
* ``act`` an numpy.ndarray or a torch.Tensor, the action over \
given batch data.
* ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
internal state of the policy, ``None`` as default.
Other keys are user-defined. It depends on the algorithm. For example,
::
# some code
return Batch(logits=..., act=..., state=None, dist=...)
"""
pass
@abstractmethod
def learn(self, batch, batch_size=None):
# return a dict which includes loss and its name
pass
def learn(self, batch, **kwargs):
"""Update policy with a given batch of data.
def sync_weight(self):
:return: A dict which includes loss and its corresponding label.
"""
pass

View File

@ -5,18 +5,35 @@ import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise
class DDPGPolicy(BasePolicy):
"""docstring for DDPGPolicy"""
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic_optim: the optimizer for critic
network.
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
:param float exploration_noise: the noise intensity, add to the action,
defaults to 0.1.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
"""
def __init__(self, actor, actor_optim, critic, critic_optim,
tau=0.005, gamma=0.99, exploration_noise=0.1,
action_range=None, reward_normalization=False,
ignore_done=False):
ignore_done=False, **kwargs):
super().__init__()
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
@ -26,9 +43,9 @@ class DDPGPolicy(BasePolicy):
self.critic, self.critic_old = critic, deepcopy(critic)
self.critic_old.eval()
self.critic_optim = critic_optim
assert 0 < tau <= 1, 'tau should in (0, 1]'
assert 0 <= tau <= 1, 'tau should in [0, 1]'
self._tau = tau
assert 0 < gamma <= 1, 'gamma should in (0, 1]'
assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
self._gamma = gamma
assert 0 <= exploration_noise, 'noise should not be negative'
self._eps = exploration_noise
@ -43,19 +60,23 @@ class DDPGPolicy(BasePolicy):
self.__eps = np.finfo(np.float32).eps.item()
def set_eps(self, eps):
"""Set the eps for exploration."""
self._eps = eps
def train(self):
"""Set the module in training mode, except for the target network."""
self.training = True
self.actor.train()
self.critic.train()
def eval(self):
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.actor.eval()
self.critic.eval()
def sync_weight(self):
"""Soft-update the weight for the target network."""
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
for o, n in zip(
@ -73,7 +94,19 @@ class DDPGPolicy(BasePolicy):
return batch
def __call__(self, batch, state=None,
model='actor', input='obs', eps=None):
model='actor', input='obs', eps=None, **kwargs):
"""Compute action over the given batch data.
:param float eps: in [0, 1], for exploration use.
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
* ``act`` the action.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
"""
model = getattr(self, model)
obs = getattr(batch, input)
logits, h = model(obs, state=state, info=batch.info)
@ -89,7 +122,7 @@ class DDPGPolicy(BasePolicy):
logits = logits.clamp(self._range[0], self._range[1])
return Batch(act=logits, state=h)
def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, **kwargs):
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)

View File

@ -8,11 +8,20 @@ from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy):
"""docstring for DQNPolicy"""
"""Implementation of Deep Q Network. arXiv:1312.5602
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int estimation_step: greater than 1, the number of steps to look
ahead.
:param int target_update_freq: the target network update frequency (``0``
if you do not use the target network).
"""
def __init__(self, model, optim, discount_factor=0.99,
estimation_step=1, use_target_network=True,
target_update_freq=300):
estimation_step=1, target_update_freq=0, **kwargs):
super().__init__()
self.model = model
self.optim = optim
@ -21,28 +30,44 @@ class DQNPolicy(BasePolicy):
self._gamma = discount_factor
assert estimation_step > 0, 'estimation_step should greater than 0'
self._n_step = estimation_step
self._target = use_target_network
self._target = target_update_freq > 0
self._freq = target_update_freq
self._cnt = 0
if use_target_network:
if self._target:
self.model_old = deepcopy(self.model)
self.model_old.eval()
def set_eps(self, eps):
"""Set the eps for epsilon-greedy exploration."""
self.eps = eps
def train(self):
"""Set the module in training mode, except for the target network."""
self.training = True
self.model.train()
def eval(self):
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.model.eval()
def sync_weight(self):
"""Synchronize the weight for the target network."""
self.model_old.load_state_dict(self.model.state_dict())
def process_fn(self, batch, buffer, indice):
r"""Compute the n-step return for Q-learning targets:
.. math::
G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
\gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a
(Q_{new}(s_{t + n}, a)))
, where :math:`\gamma` is the discount factor,
:math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
:math:`t`. If there is no target network, the :math:`Q_{old}` is equal
to :math:`Q_{new}`.
"""
returns = np.zeros_like(indice)
gammas = np.zeros_like(indice) + self._n_step
for n in range(self._n_step - 1, -1, -1):
@ -70,7 +95,20 @@ class DQNPolicy(BasePolicy):
return batch
def __call__(self, batch, state=None,
model='model', input='obs', eps=None):
model='model', input='obs', eps=None, **kwargs):
"""Compute action over the given batch data.
:param float eps: in [0, 1], for epsilon-greedy exploration method.
:return: A :class:`~tianshou.data.Batch` which has 3 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
"""
model = getattr(self, model)
obs = getattr(batch, input)
q, h = model(obs, state=state, info=batch.info)
@ -83,7 +121,7 @@ class DQNPolicy(BasePolicy):
act[i] = np.random.randint(q.shape[1])
return Batch(logits=q, act=act, state=h)
def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, **kwargs):
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()

View File

@ -1,37 +1,65 @@
import torch
import numpy as np
import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import BasePolicy
class PGPolicy(BasePolicy):
"""docstring for PGPolicy"""
"""Implementation of Vanilla Policy Gradient.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param torch.distributions.Distribution dist_fn: for computing the action.
:param float discount_factor: in [0, 1].
"""
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
discount_factor=0.99):
discount_factor=0.99, **kwargs):
super().__init__()
self.model = model
self.optim = optim
self.dist_fn = dist_fn
self._eps = np.finfo(np.float32).eps.item()
assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
self._gamma = discount_factor
def process_fn(self, batch, buffer, indice):
r"""Compute the discounted returns for each frame:
.. math::
G_t = \sum_{i=t}^T \gamma^{i-t}r_i
, where :math:`T` is the terminal time step, :math:`\gamma` is the
discount factor, :math:`\gamma \in [0, 1]`.
"""
batch.returns = self._vanilla_returns(batch)
# batch.returns = self._vectorized_returns(batch)
return batch
def __call__(self, batch, state=None):
def __call__(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
"""
logits, h = self.model(batch.obs, state=state, info=batch.info)
logits = F.softmax(logits, dim=1)
dist = self.dist_fn(logits)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
losses = []
r = batch.returns
batch.returns = (r - r.mean()) / (r.std() + self._eps)
@ -57,7 +85,7 @@ class PGPolicy(BasePolicy):
return returns
def _vectorized_returns(self, batch):
# according to my tests, it is slower than vanilla
# according to my tests, it is slower than _vanilla_returns
# import scipy.signal
convolve = np.convolve
# convolve = scipy.signal.convolve

View File

@ -9,7 +9,24 @@ from tianshou.policy import PGPolicy
class PPOPolicy(PGPolicy):
"""docstring for PPOPolicy"""
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
:param torch.distributions.Distribution dist_fn: for computing the action.
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to ``None``.
:param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
paper, defaults to 0.2.
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
"""
def __init__(self, actor, critic, optim, dist_fn,
discount_factor=0.99,
@ -17,7 +34,8 @@ class PPOPolicy(PGPolicy):
eps_clip=.2,
vf_coef=.5,
ent_coef=.0,
action_range=None):
action_range=None,
**kwargs):
super().__init__(None, None, dist_fn, discount_factor)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
@ -31,16 +49,30 @@ class PPOPolicy(PGPolicy):
self.optim = optim
def train(self):
"""Set the module in training mode, except for the target network."""
self.training = True
self.actor.train()
self.critic.train()
def eval(self):
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.actor.eval()
self.critic.eval()
def __call__(self, batch, state=None, model='actor'):
def __call__(self, batch, state=None, model='actor', **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
"""
model = getattr(self, model)
logits, h = model(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
@ -53,10 +85,11 @@ class PPOPolicy(PGPolicy):
return Batch(logits=logits, act=act, state=h, dist=dist)
def sync_weight(self):
"""Synchronize the weight for the target network."""
self.actor_old.load_state_dict(self.actor.state_dict())
self.critic_old.load_state_dict(self.critic.state_dict())
def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
r = batch.returns
batch.returns = (r - r.mean()) / (r.std() + self._eps)
@ -79,7 +112,6 @@ class PPOPolicy(PGPolicy):
clip_losses.append(clip_loss.item())
vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
vf_losses.append(vf_loss.item())
e_loss = dist.entropy().mean()
ent_losses.append(e_loss.item())
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss

View File

@ -8,12 +8,37 @@ from tianshou.policy import DDPGPolicy
class SACPolicy(DDPGPolicy):
"""docstring for SACPolicy"""
"""Implementation of Soft Actor-Critic. arXiv:1812.05905
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s,
a))
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
critic network.
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s,
a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
:param float exploration_noise: the noise intensity, add to the action,
defaults to 0.1.
:param float alpha: entropy regularization coefficient, default to 0.2.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
"""
def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
alpha=0.2, action_range=None, reward_normalization=False,
ignore_done=False):
ignore_done=False, **kwargs):
super().__init__(None, None, None, None, tau, gamma, 0,
action_range, reward_normalization, ignore_done)
self.actor, self.actor_optim = actor, actor_optim
@ -46,12 +71,11 @@ class SACPolicy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def __call__(self, batch, state=None, input='obs'):
def __call__(self, batch, state=None, input='obs', **kwargs):
obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = torch.distributions.Normal(*logits)
x = dist.rsample()
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
@ -61,7 +85,7 @@ class SACPolicy(DDPGPolicy):
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, **kwargs):
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act

View File

@ -6,13 +6,44 @@ from tianshou.policy import DDPGPolicy
class TD3Policy(DDPGPolicy):
"""docstring for TD3Policy"""
"""Implementation of Twin Delayed Deep Deterministic Policy Gradient,
arXiv:1802.09477
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s,
a))
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
critic network.
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s,
a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
:param float exploration_noise: the noise intensity, add to the action,
defaults to 0.1.
:param float policy_noise: the noise used in updating policy network,
default to 0.2.
:param int update_actor_freq: the update frequency of actor network,
default to 2.
:param float noise_clip: the clipping range used in updating policy
network, default to 0.5.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
"""
def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
noise_clip=0.5, action_range=None,
reward_normalization=False, ignore_done=False):
reward_normalization=False, ignore_done=False, **kwargs):
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization,
ignore_done)
@ -50,7 +81,7 @@ class TD3Policy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, **kwargs):
with torch.no_grad():
a_ = self(batch, model='actor_old', input='obs_next').act
dev = a_.device

View File

@ -12,34 +12,35 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
**kwargs):
"""A wrapper for off-policy trainer procedure.
Parameters
* **policy** an instance of the :class:`~tianshou.policy.BasePolicy`\
class.
* **train_collector** the collector used for training.
* **test_collector** the collector used for testing.
* **max_epoch** the maximum of epochs for training. The training \
process might be finished before reaching the ``max_epoch``.
* **step_per_epoch** the number of step for updating policy network \
in one epoch.
* **collect_per_step** the number of frames the collector would \
collect before the network update. In other words, collect some \
frames and do one policy network update.
* **episode_per_test** the number of episodes for one policy \
evaluation.
* **batch_size** the batch size of sample data, which is going to \
feed in the policy network.
* **train_fn** a function receives the current number of epoch index\
and performs some operations at the beginning of training in this \
epoch.
* **test_fn** a function receives the current number of epoch index \
and performs some operations at the beginning of testing in this \
epoch.
* **stop_fn** a function receives the average undiscounted returns \
of the testing result, return a boolean which indicates whether \
reaching the goal.
* **writer** a SummaryWriter provided from TensorBoard.
* **log_interval** an int indicating the log interval of the writer.
* **verbose** a boolean indicating whether to print the information.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
:param train_collector: the collector used for training.
:type train_collector: :class:`~tianshou.data.Collector`
:param test_collector: the collector used for testing.
:type test_collector: :class:`~tianshou.data.Collector`
:param int max_epoch: the maximum of epochs for training. The training
process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of step for updating policy network
in one epoch.
:param int collect_per_step: the number of frames the collector would
collect before the network update. In other words, collect some frames
and do one policy network update.
:param episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
:param function train_fn: a function receives the current number of epoch
index and performs some operations at the beginning of training in this
epoch.
:param function test_fn: a function receives the current number of epoch
index and performs some operations at the beginning of testing in this
epoch.
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
:param int log_interval: the log interval of the writer.
:param bool verbose: whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""

View File

@ -13,37 +13,39 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
**kwargs):
"""A wrapper for on-policy trainer procedure.
Parameters
* **policy** an instance of the :class:`~tianshou.policy.BasePolicy`\
class.
* **train_collector** the collector used for training.
* **test_collector** the collector used for testing.
* **max_epoch** the maximum of epochs for training. The training \
process might be finished before reaching the ``max_epoch``.
* **step_per_epoch** the number of step for updating policy network \
in one epoch.
* **collect_per_step** the number of frames the collector would \
collect before the network update. In other words, collect some \
frames and do one policy network update.
* **repeat_per_collect** the number of repeat time for policy \
learning, for example, set it to 2 means the policy needs to learn\
each given batch data twice.
* **episode_per_test** the number of episodes for one policy \
evaluation.
* **batch_size** the batch size of sample data, which is going to \
feed in the policy network.
* **train_fn** a function receives the current number of epoch index\
and performs some operations at the beginning of training in this \
epoch.
* **test_fn** a function receives the current number of epoch index \
and performs some operations at the beginning of testing in this \
epoch.
* **stop_fn** a function receives the average undiscounted returns \
of the testing result, return a boolean which indicates whether \
reaching the goal.
* **writer** a SummaryWriter provided from TensorBoard.
* **log_interval** an int indicating the log interval of the writer.
* **verbose** a boolean indicating whether to print the information.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
:param train_collector: the collector used for training.
:type train_collector: :class:`~tianshou.data.Collector`
:param test_collector: the collector used for testing.
:type test_collector: :class:`~tianshou.data.Collector`
:param int max_epoch: the maximum of epochs for training. The training
process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of step for updating policy network
in one epoch.
:param int collect_per_step: the number of frames the collector would
collect before the network update. In other words, collect some frames
and do one policy network update.
:param int repeat_per_collect: the number of repeat time for policy
learning, for example, set it to 2 means the policy needs to learn each
given batch data twice.
:param episode_per_test: the number of episodes for one policy evaluation.
:type episode_per_test: int or list of ints
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
:param function train_fn: a function receives the current number of epoch
index and performs some operations at the beginning of training in this
epoch.
:param function test_fn: a function receives the current number of epoch
index and performs some operations at the beginning of testing in this
epoch.
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
:param int log_interval: the log interval of the writer.
:param bool verbose: whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""

View File

@ -26,7 +26,7 @@ class MovAvg(object):
def add(self, x):
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
only one element, a python scalar, or a list of python scalar. It will
exclude the infinity.
automatically exclude the infinity.
"""
if isinstance(x, torch.Tensor):
x = x.item()