add policy docs (#21)

parent 610390c132
commit e0809ff135
@@ -36,7 +36,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is
 Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:

 ```bash
-pip3 install tianshou
+pip3 install tianshou -U
 ```

 You can also install with the newest version through GitHub:
@@ -8,14 +8,14 @@ Welcome to Tianshou!

 **Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or are slow, Tianshou provides a fast framework and a pythonic API for building deep reinforcement learning agents. The supported interface algorithms include:

-* `Policy Gradient (PG) <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
-* `Deep Q-Network (DQN) <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
-* `Double DQN (DDQN) <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
-* `Advantage Actor-Critic (A2C) <https://openai.com/blog/baselines-acktr-a2c/>`_
-* `Deep Deterministic Policy Gradient (DDPG) <https://arxiv.org/pdf/1509.02971.pdf>`_
-* `Proximal Policy Optimization (PPO) <https://arxiv.org/pdf/1707.06347.pdf>`_
-* `Twin Delayed DDPG (TD3) <https://arxiv.org/pdf/1802.09477.pdf>`_
-* `Soft Actor-Critic (SAC) <https://arxiv.org/pdf/1812.05905.pdf>`_
+* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
+* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
+* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
+* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
+* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
+* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
+* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
+* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_


 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
@@ -27,7 +27,7 @@ Installation
 Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
 ::

-    pip3 install tianshou
+    pip3 install tianshou -U

 You can also install with the newest version through GitHub:
 ::
@@ -25,13 +25,14 @@ Data Buffer

 Tianshou provides other types of data buffers such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.

+.. _policy_concept:

 Policy
 ------

 Tianshou aims to modularize RL algorithms. There are several classes of policies in Tianshou. All of the policy classes must inherit from :class:`~tianshou.policy.BasePolicy`.

 A policy class typically has four parts:

 * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on;
 * :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
@@ -119,7 +120,7 @@ There will be more types of trainers, for instance, multi-agent trainer.
 A High-level Explanation
 ------------------------

-We give a high-level explanation through the pseudocode used in section Policy:
+We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
 ::

     # pseudocode, cannot work                  # methods in tianshou
@@ -68,9 +68,9 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
 Code-level optimization
 -----------------------

-Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V_s` and :math:`V_{s'}` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
+Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function with two separate forward passes.

-Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
+.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.


 Finally
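As a rough sketch of the concatenation trick described above, the snippet below evaluates :math:`V(s)` and :math:`V(s')` with one forward pass; the ``ValueNet`` module is invented for this example and is not part of the commit.

```python
import torch
import torch.nn as nn

class ValueNet(nn.Module):
    """Toy state-value network: maps a batch of states to V(s)."""
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, s):
        return self.net(s)

critic = ValueNet(state_dim=4)
s = torch.randn(256, 4)        # batch of states s
s_next = torch.randn(256, 4)   # batch of successor states s'

# one forward pass over the concatenated batch instead of two separate passes
v_all = critic(torch.cat([s, s_next], dim=0))
v_s, v_s_next = v_all.split(s.shape[0], dim=0)
```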
@@ -5,7 +5,8 @@ import torch.nn.functional as F


 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
+    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
+                 softmax=False):
         super().__init__()
         self.device = device
         self.model = [
@@ -15,6 +16,8 @@ class Net(nn.Module):
             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
         if action_shape:
             self.model += [nn.Linear(128, np.prod(action_shape))]
+        if softmax:
+            self.model += [nn.Softmax(dim=-1)]
         self.model = nn.Sequential(*self.model)

     def forward(self, s, state=None, info={}):
@@ -113,7 +113,7 @@ def test_pg(args=get_args()):
     # model
     net = Net(
         args.layer_num, args.state_shape, args.action_shape,
-        device=args.device)
+        device=args.device, softmax=True)
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
@@ -104,10 +104,11 @@ class Batch(object):
     def split(self, size=None, permute=True):
         """Split whole data into multiple small batch.

-        :param size: if it is ``None``, it does not split the data batch;
+        :param int size: if it is ``None``, it does not split the data batch;
             otherwise it will divide the data batch with the given size.
-        :param permute: randomly shuffle the entire data batch if it is
-            ``True``, otherwise remain in the same.
+            Default to ``None``.
+        :param bool permute: randomly shuffle the entire data batch if it is
+            ``True``, otherwise remain in the same. Default to ``True``.
         """
         length = len(self)
         if size is None:
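A small usage sketch of the ``split`` behaviour documented above (the data and sizes are made up for illustration):

```python
import numpy as np
from tianshou.data import Batch

batch = Batch(obs=np.arange(10), returns=np.linspace(0., 1., 10))

# size=4 yields shuffled minibatches covering the whole batch
for minibatch in batch.split(size=4):
    print(len(minibatch.obs))  # 4, 4, 2 (order shuffled when permute=True)
```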
@@ -10,7 +10,22 @@ from tianshou.utils import MovAvg

 class Collector(object):
     """The :class:`~tianshou.data.Collector` enables the policy to interact
-    with different types of environments conveniently. Here is the usage:
+    with different types of environments conveniently.
+
+    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
+        class.
+    :param env: an environment or an instance of the
+        :class:`~tianshou.env.BaseVectorEnv` class.
+    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
+        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
+        ``None``, it will automatically assign a small-size
+        :class:`~tianshou.data.ReplayBuffer`.
+    :param int stat_size: for the moving average of recording speed, defaults
+        to 100.
+    :param bool store_obs_next: whether to store the obs_next to replay
+        buffer, defaults to ``True``.
+
+    Example:
     ::

         policy = PGPolicy(...)  # or other policies if you wish
@@ -55,7 +70,8 @@ class Collector(object):
     Please make sure the given environment has a time limitation.
     """

-    def __init__(self, policy, env, buffer=None, stat_size=100):
+    def __init__(self, policy, env, buffer=None, stat_size=100,
+                 store_obs_next=True, **kwargs):
         super().__init__()
         self.env = env
         self.env_num = 1
@@ -90,6 +106,7 @@ class Collector(object):
         self.state = None
         self.step_speed = MovAvg(stat_size)
         self.episode_speed = MovAvg(stat_size)
+        self._save_s_ = store_obs_next

     def reset_buffer(self):
         """Reset the main data buffer."""
@@ -141,11 +158,12 @@ class Collector(object):
     def collect(self, n_step=0, n_episode=0, render=0):
         """Collect a specified number of step or episode.

-        :param n_step: an int, indicates how many steps you want to collect.
-        :param n_episode: an int or a list, indicates how many episodes you
-            want to collect (in each environment).
-        :param render: a float, the sleep time between rendering consecutive
-            frames. ``0`` means no rendering.
+        :param int n_step: how many steps you want to collect.
+        :param n_episode: how many episodes you want to collect (in each
+            environment).
+        :type n_episode: int or list
+        :param float render: the sleep time between rendering consecutive
+            frames. No rendering if it is ``0`` (default option).

         .. note::

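A short usage sketch of the ``collect`` parameters documented above, assuming ``policy`` and ``envs`` have been built as in the Collector example earlier in this diff:

```python
from tianshou.data import Collector, ReplayBuffer

# buffer size and learning batch size are illustrative values
collector = Collector(policy, envs, ReplayBuffer(size=20000))
collector.collect(n_step=1000)                 # gather 1000 environment steps
collector.collect(n_episode=10, render=0.01)   # or 10 episodes, sleeping 0.01s per frame
batch = collector.sample(batch_size=64)        # draw a training batch from the buffer
```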
@@ -210,7 +228,8 @@ class Collector(object):
                 data = {
                     'obs': self._obs[i], 'act': self._act[i],
                     'rew': self._rew[i], 'done': self._done[i],
-                    'obs_next': obs_next[i], 'info': self._info[i]}
+                    'obs_next': obs_next[i] if self._save_s_ else None,
+                    'info': self._info[i]}
                 if self._cached_buf:
                     warning_count += 1
                     self._cached_buf[i].add(**data)
@@ -255,7 +274,8 @@ class Collector(object):
             else:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
-                    self._done, obs_next, self._info)
+                    self._done, obs_next if self._save_s_ else None,
+                    self._info)
             cur_step += 1
             if self._done:
                 cur_episode += 1
@@ -296,9 +316,9 @@ class Collector(object):
         :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
         the final batch data.

-        :param batch_size: an int, ``0`` means it will extract all the data
-            from the buffer, otherwise it will extract the given batch_size of
-            data.
+        :param int batch_size: ``0`` means it will extract all the data from
+            the buffer, otherwise it will extract the data with the given
+            batch_size.
         """
         if self._multi_buf:
             if batch_size > 0:
tianshou/env/vecenv.py (vendored)
@@ -60,8 +60,7 @@ class BaseVectorEnv(ABC, gym.Wrapper):

         Accept a batch of action and return a tuple (obs, rew, done, info).

-        :param action: a numpy.ndarray, a batch of action provided by the
-            agent.
+        :param numpy.ndarray action: a batch of action provided by the agent.

         :return: A tuple including four items:

@@ -7,12 +7,26 @@ from tianshou.policy import PGPolicy


 class A2CPolicy(PGPolicy):
-    """docstring for A2CPolicy"""
+    """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783
+
+    :param torch.nn.Module actor: the actor network following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.nn.Module critic: the critic network. (s -> V(s))
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic
+        network.
+    :param torch.distributions.Distribution dist_fn: for computing the action,
+        defaults to ``torch.distributions.Categorical``.
+    :param float discount_factor: in [0, 1], defaults to 0.99.
+    :param float vf_coef: weight for value loss, defaults to 0.5.
+    :param float ent_coef: weight for entropy loss, defaults to 0.01.
+    :param float max_grad_norm: clipping gradients in back propagation,
+        defaults to ``None``.
+    """

     def __init__(self, actor, critic, optim,
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
-                 max_grad_norm=None):
+                 max_grad_norm=None, **kwargs):
         super().__init__(None, optim, dist_fn, discount_factor)
         self.actor = actor
         self.critic = critic
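For orientation, a hedged sketch of how the documented A2C parameters might be wired together; the ``Actor`` and ``Critic`` modules below are minimal stand-ins invented for this example, not code from this commit:

```python
import torch
import torch.nn as nn
from tianshou.policy import A2CPolicy

class Actor(nn.Module):
    """Minimal actor: obs -> action probabilities, returned as (logits, state)."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(),
            nn.Linear(64, act_dim), nn.Softmax(dim=-1))  # softmax so Categorical gets probabilities

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.net(obs), state

class Critic(nn.Module):
    """Minimal critic: obs -> V(s)."""
    def __init__(self, obs_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, obs, **kwargs):
        return self.net(torch.as_tensor(obs, dtype=torch.float32))

actor, critic = Actor(4, 2), Critic(4)
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=3e-4)
policy = A2CPolicy(actor, critic, optim,
                   dist_fn=torch.distributions.Categorical,
                   discount_factor=0.99, vf_coef=0.5, ent_coef=0.01)
```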
@@ -20,13 +34,28 @@ class A2CPolicy(PGPolicy):
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm

-    def __call__(self, batch, state=None):
+    def __call__(self, batch, state=None, **kwargs):
+        """Compute action over the given batch data.
+
+        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+
+            * ``act`` the action.
+            * ``logits`` the network's raw output.
+            * ``dist`` the action distribution.
+            * ``state`` the hidden state.
+
+        More information can be found at
+        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        """
         logits, h = self.actor(batch.obs, state=state, info=batch.info)
-        dist = self.dist_fn(logits)
+        if isinstance(logits, tuple):
+            dist = self.dist_fn(*logits)
+        else:
+            dist = self.dist_fn(logits)
         act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)

-    def learn(self, batch, batch_size=None, repeat=1):
+    def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size):
@@ -3,23 +3,74 @@ from abc import ABC, abstractmethod


 class BasePolicy(ABC, nn.Module):
-    """docstring for BasePolicy"""
-
-    def __init__(self):
+    """Tianshou aims to modularize RL algorithms. There are several
+    classes of policies in Tianshou. All of the policy classes must inherit
+    from :class:`~tianshou.policy.BasePolicy`.
+
+    A policy class typically has four parts:
+
+    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
+        including copying the target network and so on;
+    * :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \
+        observation;
+    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
+        the replay buffer (this function can interact with replay buffer);
+    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
+        batch of data.
+
+    Most of the policies need a neural network to predict the action and an
+    optimizer to optimize the policy. The rules of self-defined networks are:
+
+    1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or \
+        ``torch.Tensor``), hidden state ``state`` (for RNN usage), and other \
+        information ``info`` provided by the environment.
+    2. Output: some ``logits`` and the next hidden state ``state``. The logits\
+        could be a tuple instead of a ``torch.Tensor``. It depends on how the \
+        policy processes the network output. For example, in PPO, the return of \
+        the network might be ``(mu, sigma), state`` for a Gaussian policy.
+
+    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
+    you can operate :class:`~tianshou.policy.BasePolicy` almost the same as
+    ``torch.nn.Module``, for instance, load and save the model:
+    ::
+
+        torch.save(policy.state_dict(), 'policy.pth')
+        policy.load_state_dict(torch.load('policy.pth'))
+    """
+
+    def __init__(self, **kwargs):
         super().__init__()

     def process_fn(self, batch, buffer, indice):
+        """Pre-process the data from the provided replay buffer. Check out
+        :ref:`policy_concept` for more information.
+        """
         return batch

     @abstractmethod
-    def __call__(self, batch, state=None):
-        # return Batch(logits=..., act=..., state=None, ...)
+    def __call__(self, batch, state=None, **kwargs):
+        """Compute action over the given batch data.
+
+        :return: A :class:`~tianshou.data.Batch` which MUST have the following\
+            keys:
+
+            * ``act`` a numpy.ndarray or a torch.Tensor, the action over \
+                given batch data.
+            * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
+                internal state of the policy, ``None`` as default.
+
+        Other keys are user-defined. It depends on the algorithm. For example,
+        ::
+
+            # some code
+            return Batch(logits=..., act=..., state=None, dist=...)
+        """
         pass

     @abstractmethod
-    def learn(self, batch, batch_size=None):
-        # return a dict which includes loss and its name
-        pass
-
-    def sync_weight(self):
+    def learn(self, batch, **kwargs):
+        """Update policy with a given batch of data.
+
+        :return: A dict which includes loss and its corresponding label.
+        """
         pass
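To make the four required parts concrete, here is a deliberately tiny random-action policy written against the interface documented above; it is an illustration only, not code from this commit:

```python
import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Picks a uniformly random discrete action; learn() is a no-op."""
    def __init__(self, action_num, **kwargs):
        super().__init__(**kwargs)
        self.action_num = action_num

    def __call__(self, batch, state=None, **kwargs):
        # one random action per observation in the batch
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(act=act, state=state)

    def learn(self, batch, **kwargs):
        # nothing to optimize; return a loss dict as required by the interface
        return {'loss': 0.0}
```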
@@ -5,18 +5,35 @@ import torch.nn.functional as F

 from tianshou.data import Batch
 from tianshou.policy import BasePolicy


 # from tianshou.exploration import OUNoise


 class DDPGPolicy(BasePolicy):
-    """docstring for DDPGPolicy"""
+    """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971
+
+    :param torch.nn.Module actor: the actor network following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+    :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
+    :param torch.optim.Optimizer critic_optim: the optimizer for critic
+        network.
+    :param float tau: param for soft update of the target network, defaults to
+        0.005.
+    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
+    :param float exploration_noise: the noise intensity, add to the action,
+        defaults to 0.1.
+    :param action_range: the action range (minimum, maximum).
+    :type action_range: [float, float]
+    :param bool reward_normalization: normalize the reward to Normal(0, 1),
+        defaults to ``False``.
+    :param bool ignore_done: ignore the done flag while training the policy,
+        defaults to ``False``.
+    """

     def __init__(self, actor, actor_optim, critic, critic_optim,
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
                  action_range=None, reward_normalization=False,
-                 ignore_done=False):
+                 ignore_done=False, **kwargs):
         super().__init__()
         if actor is not None:
             self.actor, self.actor_old = actor, deepcopy(actor)
@@ -26,9 +43,9 @@ class DDPGPolicy(BasePolicy):
             self.critic, self.critic_old = critic, deepcopy(critic)
             self.critic_old.eval()
             self.critic_optim = critic_optim
-        assert 0 < tau <= 1, 'tau should in (0, 1]'
+        assert 0 <= tau <= 1, 'tau should in [0, 1]'
         self._tau = tau
-        assert 0 < gamma <= 1, 'gamma should in (0, 1]'
+        assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
         self._gamma = gamma
         assert 0 <= exploration_noise, 'noise should not be negative'
         self._eps = exploration_noise
@@ -43,19 +60,23 @@ class DDPGPolicy(BasePolicy):
         self.__eps = np.finfo(np.float32).eps.item()

     def set_eps(self, eps):
+        """Set the eps for exploration."""
         self._eps = eps

     def train(self):
+        """Set the module in training mode, except for the target network."""
         self.training = True
         self.actor.train()
         self.critic.train()

     def eval(self):
+        """Set the module in evaluation mode, except for the target network."""
         self.training = False
         self.actor.eval()
         self.critic.eval()

     def sync_weight(self):
+        """Soft-update the weight for the target network."""
         for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
         for o, n in zip(
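The loop in ``sync_weight`` above is the usual Polyak (soft target) update, :math:`\theta' \leftarrow (1 - \tau)\theta' + \tau\theta`. A tiny standalone sketch with plain tensors (names invented for illustration):

```python
import torch

tau = 0.005
target_param = torch.zeros(3)   # a parameter of the target (old) network
source_param = torch.ones(3)    # the corresponding parameter being trained

# Polyak averaging: target <- (1 - tau) * target + tau * source
target_param.copy_(target_param * (1 - tau) + source_param * tau)
print(target_param)  # tensor([0.0050, 0.0050, 0.0050])
```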
@@ -73,7 +94,19 @@ class DDPGPolicy(BasePolicy):
         return batch

     def __call__(self, batch, state=None,
-                 model='actor', input='obs', eps=None):
+                 model='actor', input='obs', eps=None, **kwargs):
+        """Compute action over the given batch data.
+
+        :param float eps: in [0, 1], for exploration use.
+
+        :return: A :class:`~tianshou.data.Batch` which has 2 keys:
+
+            * ``act`` the action.
+            * ``state`` the hidden state.
+
+        More information can be found at
+        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        """
         model = getattr(self, model)
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
@@ -89,7 +122,7 @@ class DDPGPolicy(BasePolicy):
             logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)

-    def learn(self, batch, batch_size=None, repeat=1):
+    def learn(self, batch, **kwargs):
         with torch.no_grad():
             target_q = self.critic_old(batch.obs_next, self(
                 batch, model='actor_old', input='obs_next', eps=0).act)
@@ -8,11 +8,20 @@ from tianshou.policy import BasePolicy


 class DQNPolicy(BasePolicy):
-    """docstring for DQNPolicy"""
+    """Implementation of Deep Q Network. arXiv:1312.5602
+
+    :param torch.nn.Module model: a model following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+    :param float discount_factor: in [0, 1].
+    :param int estimation_step: greater than 1, the number of steps to look
+        ahead.
+    :param int target_update_freq: the target network update frequency (``0``
+        if you do not use the target network).
+    """

     def __init__(self, model, optim, discount_factor=0.99,
-                 estimation_step=1, use_target_network=True,
-                 target_update_freq=300):
+                 estimation_step=1, target_update_freq=0, **kwargs):
         super().__init__()
         self.model = model
         self.optim = optim
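A hedged construction sketch using the documented signature; ``Net`` refers to the small MLP added earlier in this diff, and the import path below is illustrative only:

```python
import torch
from tianshou.policy import DQNPolicy
from net import Net  # illustrative path; the Net class is defined in the test code above

net = Net(layer_num=3, state_shape=4, action_shape=2)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.99,
                   estimation_step=3,       # use 3-step returns
                   target_update_freq=320)  # sync the target network every 320 updates
policy.set_eps(0.1)                         # epsilon-greedy exploration while collecting
```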
@@ -21,28 +30,44 @@ class DQNPolicy(BasePolicy):
         self._gamma = discount_factor
         assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step
-        self._target = use_target_network
+        self._target = target_update_freq > 0
         self._freq = target_update_freq
         self._cnt = 0
-        if use_target_network:
+        if self._target:
             self.model_old = deepcopy(self.model)
             self.model_old.eval()

     def set_eps(self, eps):
+        """Set the eps for epsilon-greedy exploration."""
         self.eps = eps

     def train(self):
+        """Set the module in training mode, except for the target network."""
         self.training = True
         self.model.train()

     def eval(self):
+        """Set the module in evaluation mode, except for the target network."""
         self.training = False
         self.model.eval()

     def sync_weight(self):
+        """Synchronize the weight for the target network."""
         self.model_old.load_state_dict(self.model.state_dict())

     def process_fn(self, batch, buffer, indice):
+        r"""Compute the n-step return for Q-learning targets:
+
+        .. math::
+            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
+            \gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a
+            (Q_{new}(s_{t + n}, a)))
+
+        , where :math:`\gamma` is the discount factor,
+        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
+        :math:`t`. If there is no target network, the :math:`Q_{old}` is equal
+        to :math:`Q_{new}`.
+        """
         returns = np.zeros_like(indice)
         gammas = np.zeros_like(indice) + self._n_step
         for n in range(self._n_step - 1, -1, -1):
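A small numeric sketch of the n-step target above, with :math:`n = 3`, :math:`\gamma = 0.99`, no terminal flags inside the window, and a made-up bootstrap value standing in for :math:`\max_a Q_{old}(s_{t+n}, a)`:

```python
import numpy as np

gamma, n = 0.99, 3
rewards = np.array([1.0, 1.0, 1.0])   # r_t, r_{t+1}, r_{t+2}
dones = np.array([0.0, 0.0, 0.0])     # no episode end inside the window
bootstrap_q = 2.0                     # placeholder for max_a Q_old(s_{t+n}, a)

# n-step target: sum_i gamma^i (1 - d_i) r_i + gamma^n (1 - d_{t+n}) * bootstrap
g = sum((gamma ** i) * (1 - dones[i]) * rewards[i] for i in range(n))
g += (gamma ** n) * (1 - 0.0) * bootstrap_q   # (1 - d_{t+n}) = 1 here
print(round(g, 4))  # 4.9107
```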
@@ -70,7 +95,20 @@ class DQNPolicy(BasePolicy):
         return batch

     def __call__(self, batch, state=None,
-                 model='model', input='obs', eps=None):
+                 model='model', input='obs', eps=None, **kwargs):
+        """Compute action over the given batch data.
+
+        :param float eps: in [0, 1], for epsilon-greedy exploration method.
+
+        :return: A :class:`~tianshou.data.Batch` which has 3 keys:
+
+            * ``act`` the action.
+            * ``logits`` the network's raw output.
+            * ``state`` the hidden state.
+
+        More information can be found at
+        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        """
         model = getattr(self, model)
         obs = getattr(batch, input)
         q, h = model(obs, state=state, info=batch.info)
@@ -83,7 +121,7 @@ class DQNPolicy(BasePolicy):
                 act[i] = np.random.randint(q.shape[1])
         return Batch(logits=q, act=act, state=h)

-    def learn(self, batch, batch_size=None, repeat=1):
+    def learn(self, batch, **kwargs):
         if self._target and self._cnt % self._freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -1,37 +1,65 @@
 import torch
 import numpy as np
-import torch.nn.functional as F

 from tianshou.data import Batch
 from tianshou.policy import BasePolicy


 class PGPolicy(BasePolicy):
-    """docstring for PGPolicy"""
+    """Implementation of Vanilla Policy Gradient.
+
+    :param torch.nn.Module model: a model following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+    :param torch.distributions.Distribution dist_fn: for computing the action.
+    :param float discount_factor: in [0, 1].
+    """

     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99):
+                 discount_factor=0.99, **kwargs):
         super().__init__()
         self.model = model
         self.optim = optim
         self.dist_fn = dist_fn
         self._eps = np.finfo(np.float32).eps.item()
-        assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
+        assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
         self._gamma = discount_factor

     def process_fn(self, batch, buffer, indice):
+        r"""Compute the discounted returns for each frame:
+
+        .. math::
+            G_t = \sum_{i=t}^T \gamma^{i-t}r_i
+
+        , where :math:`T` is the terminal time step, :math:`\gamma` is the
+        discount factor, :math:`\gamma \in [0, 1]`.
+        """
         batch.returns = self._vanilla_returns(batch)
         # batch.returns = self._vectorized_returns(batch)
         return batch

-    def __call__(self, batch, state=None):
+    def __call__(self, batch, state=None, **kwargs):
+        """Compute action over the given batch data.
+
+        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+
+            * ``act`` the action.
+            * ``logits`` the network's raw output.
+            * ``dist`` the action distribution.
+            * ``state`` the hidden state.
+
+        More information can be found at
+        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        """
         logits, h = self.model(batch.obs, state=state, info=batch.info)
-        logits = F.softmax(logits, dim=1)
-        dist = self.dist_fn(logits)
+        if isinstance(logits, tuple):
+            dist = self.dist_fn(*logits)
+        else:
+            dist = self.dist_fn(logits)
         act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)

-    def learn(self, batch, batch_size=None, repeat=1):
+    def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         losses = []
         r = batch.returns
         batch.returns = (r - r.mean()) / (r.std() + self._eps)
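The discounted return above can be computed with a single backward recursion :math:`G_t = r_t + \gamma G_{t+1}`; the sketch below illustrates the idea on a toy reward sequence (it is not the library's ``_vanilla_returns``, just the same recurrence):

```python
import numpy as np

def discounted_returns(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns = np.zeros_like(rewards, dtype=float)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(discounted_returns(np.array([1.0, 0.0, 2.0]), gamma=0.5))  # [1.5 1.  2. ]
```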
@@ -57,7 +85,7 @@ class PGPolicy(BasePolicy):
         return returns

     def _vectorized_returns(self, batch):
-        # according to my tests, it is slower than vanilla
+        # according to my tests, it is slower than _vanilla_returns
         # import scipy.signal
         convolve = np.convolve
         # convolve = scipy.signal.convolve
@@ -9,7 +9,24 @@ from tianshou.policy import PGPolicy


 class PPOPolicy(PGPolicy):
-    """docstring for PPOPolicy"""
+    r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347
+
+    :param torch.nn.Module actor: the actor network following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.nn.Module critic: the critic network. (s -> V(s))
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic
+        network.
+    :param torch.distributions.Distribution dist_fn: for computing the action.
+    :param float discount_factor: in [0, 1], defaults to 0.99.
+    :param float max_grad_norm: clipping gradients in back propagation,
+        defaults to ``None``.
+    :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
+        paper, defaults to 0.2.
+    :param float vf_coef: weight for value loss, defaults to 0.5.
+    :param float ent_coef: weight for entropy loss, defaults to 0.01.
+    :param action_range: the action range (minimum, maximum).
+    :type action_range: [float, float]
+    """

     def __init__(self, actor, critic, optim, dist_fn,
                  discount_factor=0.99,
@@ -17,7 +34,8 @@ class PPOPolicy(PGPolicy):
                  eps_clip=.2,
                  vf_coef=.5,
                  ent_coef=.0,
-                 action_range=None):
+                 action_range=None,
+                 **kwargs):
         super().__init__(None, None, dist_fn, discount_factor)
         self._max_grad_norm = max_grad_norm
         self._eps_clip = eps_clip
@@ -31,16 +49,30 @@ class PPOPolicy(PGPolicy):
         self.optim = optim

     def train(self):
+        """Set the module in training mode, except for the target network."""
         self.training = True
         self.actor.train()
         self.critic.train()

     def eval(self):
+        """Set the module in evaluation mode, except for the target network."""
         self.training = False
         self.actor.eval()
         self.critic.eval()

-    def __call__(self, batch, state=None, model='actor'):
+    def __call__(self, batch, state=None, model='actor', **kwargs):
+        """Compute action over the given batch data.
+
+        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+
+            * ``act`` the action.
+            * ``logits`` the network's raw output.
+            * ``dist`` the action distribution.
+            * ``state`` the hidden state.
+
+        More information can be found at
+        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        """
         model = getattr(self, model)
         logits, h = model(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
@@ -53,10 +85,11 @@ class PPOPolicy(PGPolicy):
         return Batch(logits=logits, act=act, state=h, dist=dist)

     def sync_weight(self):
+        """Synchronize the weight for the target network."""
         self.actor_old.load_state_dict(self.actor.state_dict())
         self.critic_old.load_state_dict(self.critic.state_dict())

-    def learn(self, batch, batch_size=None, repeat=1):
+    def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         r = batch.returns
         batch.returns = (r - r.mean()) / (r.std() + self._eps)
@@ -79,7 +112,6 @@ class PPOPolicy(PGPolicy):
                 clip_losses.append(clip_loss.item())
                 vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
                 vf_losses.append(vf_loss.item())
-
                 e_loss = dist.entropy().mean()
                 ent_losses.append(e_loss.item())
                 loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
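For reference, the ``clip_loss`` combined in the last line above corresponds (up to sign conventions) to the clipped surrogate objective of the PPO paper, with :math:`\epsilon` given by ``eps_clip``:

.. math::
    L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\hat{A}_t,\;
    \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\hat{A}_t\right)\right],
    \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

The total loss then adds the value loss weighted by ``vf_coef`` and subtracts the entropy bonus weighted by ``ent_coef``, exactly as in ``loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss``.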
@@ -8,12 +8,37 @@ from tianshou.policy import DDPGPolicy


 class SACPolicy(DDPGPolicy):
-    """docstring for SACPolicy"""
+    """Implementation of Soft Actor-Critic. arXiv:1812.05905
+
+    :param torch.nn.Module actor: the actor network following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s,
+        a))
+    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+        critic network.
+    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s,
+        a))
+    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+        critic network.
+    :param float tau: param for soft update of the target network, defaults to
+        0.005.
+    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
+    :param float exploration_noise: the noise intensity, add to the action,
+        defaults to 0.1.
+    :param float alpha: entropy regularization coefficient, default to 0.2.
+    :param action_range: the action range (minimum, maximum).
+    :type action_range: [float, float]
+    :param bool reward_normalization: normalize the reward to Normal(0, 1),
+        defaults to ``False``.
+    :param bool ignore_done: ignore the done flag while training the policy,
+        defaults to ``False``.
+    """

     def __init__(self, actor, actor_optim, critic1, critic1_optim,
                  critic2, critic2_optim, tau=0.005, gamma=0.99,
                  alpha=0.2, action_range=None, reward_normalization=False,
-                 ignore_done=False):
+                 ignore_done=False, **kwargs):
         super().__init__(None, None, None, None, tau, gamma, 0,
                          action_range, reward_normalization, ignore_done)
         self.actor, self.actor_optim = actor, actor_optim
@@ -46,12 +71,11 @@ class SACPolicy(DDPGPolicy):
                 self.critic2_old.parameters(), self.critic2.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

-    def __call__(self, batch, state=None, input='obs'):
+    def __call__(self, batch, state=None, input='obs', **kwargs):
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
         dist = torch.distributions.Normal(*logits)
-
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
@@ -61,7 +85,7 @@ class SACPolicy(DDPGPolicy):
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)

-    def learn(self, batch, batch_size=None, repeat=1):
+    def learn(self, batch, **kwargs):
         with torch.no_grad():
             obs_next_result = self(batch, input='obs_next')
             a_ = obs_next_result.act
@@ -6,13 +6,44 @@ from tianshou.policy import DDPGPolicy


 class TD3Policy(DDPGPolicy):
-    """docstring for TD3Policy"""
+    """Implementation of Twin Delayed Deep Deterministic Policy Gradient,
+    arXiv:1802.09477
+
+    :param torch.nn.Module actor: the actor network following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s,
+        a))
+    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+        critic network.
+    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s,
+        a))
+    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+        critic network.
+    :param float tau: param for soft update of the target network, defaults to
+        0.005.
+    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
+    :param float exploration_noise: the noise intensity, add to the action,
+        defaults to 0.1.
+    :param float policy_noise: the noise used in updating policy network,
+        default to 0.2.
+    :param int update_actor_freq: the update frequency of actor network,
+        default to 2.
+    :param float noise_clip: the clipping range used in updating policy
+        network, default to 0.5.
+    :param action_range: the action range (minimum, maximum).
+    :type action_range: [float, float]
+    :param bool reward_normalization: normalize the reward to Normal(0, 1),
+        defaults to ``False``.
+    :param bool ignore_done: ignore the done flag while training the policy,
+        defaults to ``False``.
+    """

     def __init__(self, actor, actor_optim, critic1, critic1_optim,
                  critic2, critic2_optim, tau=0.005, gamma=0.99,
                  exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
                  noise_clip=0.5, action_range=None,
-                 reward_normalization=False, ignore_done=False):
+                 reward_normalization=False, ignore_done=False, **kwargs):
         super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, action_range, reward_normalization,
                          ignore_done)
@@ -50,7 +81,7 @@ class TD3Policy(DDPGPolicy):
                 self.critic2_old.parameters(), self.critic2.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

-    def learn(self, batch, batch_size=None, repeat=1):
+    def learn(self, batch, **kwargs):
         with torch.no_grad():
             a_ = self(batch, model='actor_old', input='obs_next').act
             dev = a_.device
@@ -12,34 +12,35 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      **kwargs):
     """A wrapper for off-policy trainer procedure.

-    Parameters
-        * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\
-            class.
-        * **train_collector** – the collector used for training.
-        * **test_collector** – the collector used for testing.
-        * **max_epoch** – the maximum of epochs for training. The training \
-            process might be finished before reaching the ``max_epoch``.
-        * **step_per_epoch** – the number of step for updating policy network \
-            in one epoch.
-        * **collect_per_step** – the number of frames the collector would \
-            collect before the network update. In other words, collect some \
-            frames and do one policy network update.
-        * **episode_per_test** – the number of episodes for one policy \
-            evaluation.
-        * **batch_size** – the batch size of sample data, which is going to \
-            feed in the policy network.
-        * **train_fn** – a function receives the current number of epoch index\
-            and performs some operations at the beginning of training in this \
-            epoch.
-        * **test_fn** – a function receives the current number of epoch index \
-            and performs some operations at the beginning of testing in this \
-            epoch.
-        * **stop_fn** – a function receives the average undiscounted returns \
-            of the testing result, return a boolean which indicates whether \
-            reaching the goal.
-        * **writer** – a SummaryWriter provided from TensorBoard.
-        * **log_interval** – an int indicating the log interval of the writer.
-        * **verbose** – a boolean indicating whether to print the information.
+    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
+        class.
+    :param train_collector: the collector used for training.
+    :type train_collector: :class:`~tianshou.data.Collector`
+    :param test_collector: the collector used for testing.
+    :type test_collector: :class:`~tianshou.data.Collector`
+    :param int max_epoch: the maximum of epochs for training. The training
+        process might be finished before reaching the ``max_epoch``.
+    :param int step_per_epoch: the number of step for updating policy network
+        in one epoch.
+    :param int collect_per_step: the number of frames the collector would
+        collect before the network update. In other words, collect some frames
+        and do one policy network update.
+    :param episode_per_test: the number of episodes for one policy evaluation.
+    :param int batch_size: the batch size of sample data, which is going to
+        feed in the policy network.
+    :param function train_fn: a function receives the current number of epoch
+        index and performs some operations at the beginning of training in this
+        epoch.
+    :param function test_fn: a function receives the current number of epoch
+        index and performs some operations at the beginning of testing in this
+        epoch.
+    :param function stop_fn: a function receives the average undiscounted
+        returns of the testing result, return a boolean which indicates whether
+        reaching the goal.
+    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
+        SummaryWriter.
+    :param int log_interval: the log interval of the writer.
+    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
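A hedged sketch of a call using the parameters documented above, assuming ``policy``, ``train_collector`` and ``test_collector`` have been built as in the earlier examples (the numbers are illustrative):

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.trainer import offpolicy_trainer

writer = SummaryWriter('log/dqn')
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64,
    train_fn=lambda epoch: policy.set_eps(0.1),   # more exploration while training
    test_fn=lambda epoch: policy.set_eps(0.05),   # less exploration while testing
    stop_fn=lambda mean_rewards: mean_rewards >= 195,
    writer=writer)
print(result)  # the statistics dict returned via gather_info
```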
@@ -13,37 +13,39 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     **kwargs):
     """A wrapper for on-policy trainer procedure.

-    Parameters
-        * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\
-            class.
-        * **train_collector** – the collector used for training.
-        * **test_collector** – the collector used for testing.
-        * **max_epoch** – the maximum of epochs for training. The training \
-            process might be finished before reaching the ``max_epoch``.
-        * **step_per_epoch** – the number of step for updating policy network \
-            in one epoch.
-        * **collect_per_step** – the number of frames the collector would \
-            collect before the network update. In other words, collect some \
-            frames and do one policy network update.
-        * **repeat_per_collect** – the number of repeat time for policy \
-            learning, for example, set it to 2 means the policy needs to learn\
-            each given batch data twice.
-        * **episode_per_test** – the number of episodes for one policy \
-            evaluation.
-        * **batch_size** – the batch size of sample data, which is going to \
-            feed in the policy network.
-        * **train_fn** – a function receives the current number of epoch index\
-            and performs some operations at the beginning of training in this \
-            epoch.
-        * **test_fn** – a function receives the current number of epoch index \
-            and performs some operations at the beginning of testing in this \
-            epoch.
-        * **stop_fn** – a function receives the average undiscounted returns \
-            of the testing result, return a boolean which indicates whether \
-            reaching the goal.
-        * **writer** – a SummaryWriter provided from TensorBoard.
-        * **log_interval** – an int indicating the log interval of the writer.
-        * **verbose** – a boolean indicating whether to print the information.
+    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
+        class.
+    :param train_collector: the collector used for training.
+    :type train_collector: :class:`~tianshou.data.Collector`
+    :param test_collector: the collector used for testing.
+    :type test_collector: :class:`~tianshou.data.Collector`
+    :param int max_epoch: the maximum of epochs for training. The training
+        process might be finished before reaching the ``max_epoch``.
+    :param int step_per_epoch: the number of step for updating policy network
+        in one epoch.
+    :param int collect_per_step: the number of frames the collector would
+        collect before the network update. In other words, collect some frames
+        and do one policy network update.
+    :param int repeat_per_collect: the number of repeat time for policy
+        learning, for example, set it to 2 means the policy needs to learn each
+        given batch data twice.
+    :param episode_per_test: the number of episodes for one policy evaluation.
+    :type episode_per_test: int or list of ints
+    :param int batch_size: the batch size of sample data, which is going to
+        feed in the policy network.
+    :param function train_fn: a function receives the current number of epoch
+        index and performs some operations at the beginning of training in this
+        epoch.
+    :param function test_fn: a function receives the current number of epoch
+        index and performs some operations at the beginning of testing in this
+        epoch.
+    :param function stop_fn: a function receives the average undiscounted
+        returns of the testing result, return a boolean which indicates whether
+        reaching the goal.
+    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
+        SummaryWriter.
+    :param int log_interval: the log interval of the writer.
+    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
@@ -26,7 +26,7 @@ class MovAvg(object):
     def add(self, x):
         """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
         only one element, a python scalar, or a list of python scalar. It will
-        exclude the infinity.
+        automatically exclude the infinity.
         """
         if isinstance(x, torch.Tensor):
             x = x.item()
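A small usage sketch of the behaviour described above (assuming the usual ``add``/``get`` pair on ``MovAvg``):

```python
from tianshou.utils import MovAvg

stat = MovAvg(100)         # moving average over the last 100 entries
stat.add(1.0)
stat.add(float('inf'))     # infinite values are ignored, as documented above
stat.add([2.0, 3.0])       # lists of scalars are also accepted
print(stat.get())          # 2.0, the mean of the finite values added so far
```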