Trinkle23897 2020-05-27 11:02:23 +08:00
parent 6237cc0d52
commit de556fd22d
15 changed files with 31 additions and 29 deletions

View File

@@ -20,14 +20,14 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
 - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
+- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - Vanilla Imitation Learning
-- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
+- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
+- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
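
The Prioritized DQN entry above pairs the existing `DQNPolicy` with the prioritized buffer rather than adding a new policy class. A rough sketch of that wiring, assuming the 0.2.x-era API; the `QNet` module and every hyperparameter value below are illustrative, not taken from this commit, and the exact constructor signatures may differ:

```python
import gym
import torch
from tianshou.policy import DQNPolicy
from tianshou.data import Collector, PrioritizedReplayBuffer


class QNet(torch.nn.Module):
    """Illustrative Q-network following the (s -> logits) rule of BasePolicy."""

    def __init__(self, state_dim=4, action_dim=2):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(state_dim, 128), torch.nn.ReLU(),
            torch.nn.Linear(128, action_dim))

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.model(obs), state


env = gym.make('CartPole-v0')
net = QNet()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.99,
                   estimation_step=3, target_update_freq=320)
policy.set_eps(0.1)  # epsilon-greedy exploration while collecting

# Prioritized replay in place of the uniform ReplayBuffer; alpha/beta follow
# the notation of the PER paper and are illustrative values.
buf = PrioritizedReplayBuffer(20000, alpha=0.6, beta=0.4)
collector = Collector(policy, env, buf)
collector.collect(n_step=1000)
```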

View File

@@ -11,14 +11,14 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
+* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
-* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
+* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
+* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
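
"Parallel workers" above refers to the vectorized environment layer; a minimal sketch with the `VectorEnv` class imported in the examples touched by this commit (the environment and the worker count are illustrative):

```python
import gym
import numpy as np
from tianshou.env import VectorEnv

# Eight CartPole copies stepped in lockstep; a Collector built on top of this
# writes transitions from every worker into the same replay buffer.
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])

obs = envs.reset()                      # stacked observations, one per worker
act = np.random.randint(0, 2, size=8)   # CartPole has two discrete actions
obs, rew, done, info = envs.step(act)   # each return value has length 8
envs.close()
```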

View File

@@ -20,5 +20,14 @@ def test_batch():
     print(batch)
 
 
+def test_batch_over_batch():
+    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
+    batch2 = Batch(b=batch, c=[6, 7, 8])
+    batch2.b.b[-1] = 0
+    print(batch2)
+    assert batch2[-1].b.b == 0
+
+
 if __name__ == '__main__':
     test_batch()
+    test_batch_over_batch()

View File

@@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.env import VectorEnv
 from tianshou.policy import PPOPolicy
-from tianshou.policy.utils import DiagGaussian
+from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer

View File

@@ -70,7 +70,7 @@ class Batch(object):
         super().__init__()
         self._meta = {}
         for k, v in kwargs.items():
-            if (isinstance(v, list) or isinstance(v, np.ndarray)) \
+            if isinstance(v, (list, np.ndarray)) \
                     and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
                 self._meta[k] = list(v[0].keys())
                 for k_ in v[0].keys():
@@ -78,7 +78,7 @@ class Batch(object):
                     self.__dict__[k__] = np.array([
                         v[i][k_] for i in range(len(v))
                     ])
-            elif isinstance(v, dict) or isinstance(v, Batch):
+            elif isinstance(v, dict):
                 self._meta[k] = list(v.keys())
                 for k_ in v.keys():
                     k__ = '_' + k + '@' + k_

View File

@@ -151,7 +151,7 @@ class ReplayBuffer(object):
         if self.__dict__.get(name, None) is None:
             if isinstance(inst, np.ndarray):
                 self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
-            elif isinstance(inst, dict) or isinstance(inst, Batch):
+            elif isinstance(inst, (dict, Batch)):
                 if name == 'info':
                     self.__dict__[name] = np.array(
                         [{} for _ in range(self._maxsize)])

View File

@@ -192,15 +192,13 @@ class Collector(object):
             return
         if isinstance(self.state, list):
             self.state[id] = None
-        elif isinstance(self.state, dict) or isinstance(self.state, Batch):
+        elif isinstance(self.state, (dict, Batch)):
             for k in self.state.keys():
                 if isinstance(self.state[k], list):
                     self.state[k][id] = None
-                elif isinstance(self.state[k], torch.Tensor) or \
-                        isinstance(self.state[k], np.ndarray):
+                elif isinstance(self.state[k], (torch.Tensor, np.ndarray)):
                     self.state[k][id] = 0
-        elif isinstance(self.state, torch.Tensor) or \
-                isinstance(self.state, np.ndarray):
+        elif isinstance(self.state, (torch.Tensor, np.ndarray)):
             self.state[id] = 0
 
     def _to_numpy(self, x: Union[

View File

@@ -102,7 +102,7 @@ class BasePolicy(ABC, nn.Module):
             gamma: float = 0.99,
             gae_lambda: float = 0.95) -> Batch:
         """Compute returns over given full-length episodes, including the
-        implementation of Generalized Advantage Estimation (arXiv:1506.02438).
+        implementation of Generalized Advantage Estimator (arXiv:1506.02438).
 
         :param batch: a data batch which contains several full-episode data
             chronologically.
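
For readers who do not follow the arXiv link: the estimator the renamed docstring refers to is the standard one from Schulman et al. (arXiv:1506.02438). This is the paper's formula, not code from this commit, with discount $\gamma$ and GAE parameter $\lambda$:

```latex
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
\hat{A}_t^{\mathrm{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^{l}\, \delta_{t+l}
```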

View File

@@ -2,9 +2,7 @@ import torch
 
 
 class DiagGaussian(torch.distributions.Normal):
-    """Diagonal Gaussian Distribution
-    """
+    """Diagonal Gaussian distribution."""
 
     def log_prob(self, actions):
         return super().log_prob(actions).sum(-1, keepdim=True)
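
A small usage sketch of the class above under its new `tianshou.policy.dist` import path; it is typically supplied as the `dist_fn` of the continuous-action policies, and the batch size and action dimension below are illustrative:

```python
import torch
from tianshou.policy.dist import DiagGaussian

# The continuous-control actor outputs a mean and a standard deviation per
# action dimension; DiagGaussian treats the dimensions as independent Normals
# and sums their log-densities, so log_prob yields one value per action vector.
mu = torch.zeros(4, 3)        # batch of 4 means over a 3-dimensional action space
sigma = torch.ones(4, 3)      # matching standard deviations
dist = DiagGaussian(mu, sigma)

act = dist.sample()           # shape (4, 3)
logp = dist.log_prob(act)     # shape (4, 1), summed over the last dimension
```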

View File

@@ -55,7 +55,6 @@ class A2CPolicy(PGPolicy):
         self._grad_norm = max_grad_norm
         self._batch = 64
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
@@ -99,7 +98,7 @@ class A2CPolicy(PGPolicy):
              **kwargs) -> Dict[str, List[float]]:
         self._batch = batch_size
         r = batch.returns
-        if self._rew_norm and r.std() > self.__eps:
+        if self._rew_norm and not np.isclose(r.std(), 0):
             batch.returns = (r - r.mean()) / r.std()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):

View File

@@ -71,7 +71,6 @@ class DDPGPolicy(BasePolicy):
         # self.noise = OUNoise()
         self._rm_done = ignore_done
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def set_eps(self, eps: float) -> None:
         """Set the eps for exploration."""
@@ -102,7 +101,7 @@ class DDPGPolicy(BasePolicy):
         if self._rew_norm:
             bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
             mean, std = bfr.mean(), bfr.std()
-            if std > self.__eps:
+            if not np.isclose(std, 0):
                 batch.rew = (batch.rew - mean) / std
         if self._rm_done:
             batch.done = batch.done * 0.

View File

@@ -10,6 +10,7 @@ from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
 
 class DQNPolicy(BasePolicy):
     """Implementation of Deep Q Network. arXiv:1312.5602
+    Implementation of Double Q-Learning. arXiv:1509.06461
 
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
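
As a reminder of what the newly cited Double Q-Learning reference changes (this is the standard one-step target from arXiv:1509.06461, not code from this commit): the online network $\theta$ selects the bootstrap action and the target network $\theta^-$ evaluates it,

```latex
y_t = r_t + \gamma \, Q_{\theta^-}\!\bigl(s_{t+1}, \operatorname*{arg\,max}_{a} Q_{\theta}(s_{t+1}, a)\bigr)
```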

View File

@@ -36,7 +36,6 @@ class PGPolicy(BasePolicy):
         assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
         self._gamma = discount_factor
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
@@ -83,7 +82,7 @@ class PGPolicy(BasePolicy):
              **kwargs) -> Dict[str, List[float]]:
         losses = []
         r = batch.returns
-        if self._rew_norm and r.std() > self.__eps:
+        if self._rew_norm and not np.isclose(r.std(), 0):
             batch.returns = (r - r.mean()) / r.std()
         for _ in range(repeat):
             for b in batch.split(batch_size):

View File

@@ -53,7 +53,7 @@ class PPOPolicy(PGPolicy):
                  ent_coef: float = .01,
                  action_range: Optional[Tuple[float, float]] = None,
                  gae_lambda: float = 0.95,
-                 dual_clip: float = None,
+                 dual_clip: Optional[float] = None,
                  value_clip: bool = True,
                  reward_normalization: bool = True,
                  **kwargs) -> None:
@@ -74,13 +74,12 @@ class PPOPolicy(PGPolicy):
         self._dual_clip = dual_clip
         self._value_clip = value_clip
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
         if self._rew_norm:
             mean, std = batch.rew.mean(), batch.rew.std()
-            if std > self.__eps:
+            if not np.isclose(std, 0):
                 batch.rew = (batch.rew - mean) / std
         if self._lambda in [0, 1]:
             return self.compute_episodic_return(
@@ -140,12 +139,12 @@ class PPOPolicy(PGPolicy):
         ).reshape(batch.v.shape)
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
-            if std > self.__eps:
+            if not np.isclose(std.item(), 0):
                 batch.returns = (batch.returns - mean) / std
         batch.adv = batch.returns - batch.v
         if self._rew_norm:
             mean, std = batch.adv.mean(), batch.adv.std()
-            if std > self.__eps:
+            if not np.isclose(std.item(), 0):
                 batch.adv = (batch.adv - mean) / std
         for _ in range(repeat):
             for b in batch.split(batch_size):

View File

@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional
 from tianshou.data import Batch
 from tianshou.policy import DDPGPolicy
-from tianshou.policy.utils import DiagGaussian
+from tianshou.policy.dist import DiagGaussian
 
 
 class SACPolicy(DDPGPolicy):