item3 of #51
parent 6237cc0d52
commit de556fd22d
@@ -20,14 +20,14 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
 - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
 - [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - Vanilla Imitation Learning
-- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
+- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)

 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
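The "parallel workers" mentioned in the paragraph above go through the vectorized environment wrapper that the example script later in this commit imports. A minimal sketch, assuming `VectorEnv` accepts a list of environment factories and that `gym` and `CartPole-v0` are available (both are illustrative choices, not part of this commit):

```python
import gym
from tianshou.env import VectorEnv

# eight parallel workers, each built from its own factory function (illustrative sketch)
train_envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
obs = train_envs.reset()   # stacked observations, one row per worker
print(obs.shape)           # e.g. (8, 4) for CartPole
```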
@@ -11,14 +11,14 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
 * :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
-* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
+* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
@@ -20,5 +20,14 @@ def test_batch():
     print(batch)


+def test_batch_over_batch():
+    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
+    batch2 = Batch(b=batch, c=[6, 7, 8])
+    batch2.b.b[-1] = 0
+    print(batch2)
+    assert batch2[-1].b.b == 0
+
+
 if __name__ == '__main__':
     test_batch()
+    test_batch_over_batch()
@@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.env import VectorEnv
 from tianshou.policy import PPOPolicy
-from tianshou.policy.utils import DiagGaussian
+from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer

@@ -70,7 +70,7 @@ class Batch(object):
         super().__init__()
         self._meta = {}
         for k, v in kwargs.items():
-            if (isinstance(v, list) or isinstance(v, np.ndarray)) \
+            if isinstance(v, (list, np.ndarray)) \
                     and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
                 self._meta[k] = list(v[0].keys())
                 for k_ in v[0].keys():
@@ -78,7 +78,7 @@ class Batch(object):
                     self.__dict__[k__] = np.array([
                         v[i][k_] for i in range(len(v))
                     ])
-            elif isinstance(v, dict) or isinstance(v, Batch):
+            elif isinstance(v, dict):
                 self._meta[k] = list(v.keys())
                 for k_ in v.keys():
                     k__ = '_' + k + '@' + k_
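For orientation, the constructor logic in the two hunks above flattens dict-valued keyword arguments into mangled attribute names of the form `'_' + k + '@' + k_` and records the sub-keys in `_meta`. A stripped-down, illustrative imitation of that loop using plain dicts (not the real `Batch` class; the `'obs'` input is hypothetical):

```python
import numpy as np

kwargs = {'obs': {'pos': np.zeros(3), 'vel': np.ones(3)}}  # hypothetical nested input
meta, flat = {}, {}
for k, v in kwargs.items():
    if isinstance(v, dict):
        meta[k] = list(v.keys())
        for k_ in v.keys():
            flat['_' + k + '@' + k_] = v[k_]   # same naming scheme as the hunk above
# meta -> {'obs': ['pos', 'vel']}
# flat -> {'_obs@pos': array([0., 0., 0.]), '_obs@vel': array([1., 1., 1.])}
```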
@@ -151,7 +151,7 @@ class ReplayBuffer(object):
         if self.__dict__.get(name, None) is None:
             if isinstance(inst, np.ndarray):
                 self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
-            elif isinstance(inst, dict) or isinstance(inst, Batch):
+            elif isinstance(inst, (dict, Batch)):
                 if name == 'info':
                     self.__dict__[name] = np.array(
                         [{} for _ in range(self._maxsize)])
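The branch above allocates a buffer column lazily from the first instance it sees: a fixed-size zero array for `np.ndarray` data, and an object array of empty dicts for the `info` field. A small hedged illustration with made-up sizes (names below are not buffer attributes):

```python
import numpy as np

maxsize = 4                                         # hypothetical buffer capacity
inst = np.zeros((2, 3))                             # first observation pushed into the buffer
obs_col = np.zeros([maxsize, *inst.shape])          # pre-allocated column, shape (4, 2, 3)
info_col = np.array([{} for _ in range(maxsize)])   # 'info' keeps one raw dict per slot
```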
@@ -192,15 +192,13 @@ class Collector(object):
             return
         if isinstance(self.state, list):
             self.state[id] = None
-        elif isinstance(self.state, dict) or isinstance(self.state, Batch):
+        elif isinstance(self.state, (dict, Batch)):
             for k in self.state.keys():
                 if isinstance(self.state[k], list):
                     self.state[k][id] = None
-                elif isinstance(self.state[k], torch.Tensor) or \
-                        isinstance(self.state[k], np.ndarray):
+                elif isinstance(self.state[k], (torch.Tensor, np.ndarray)):
                     self.state[k][id] = 0
-        elif isinstance(self.state, torch.Tensor) or \
-                isinstance(self.state, np.ndarray):
+        elif isinstance(self.state, (torch.Tensor, np.ndarray)):
             self.state[id] = 0

     def _to_numpy(self, x: Union[
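The recurring edit in this commit, here and in the `Batch`/`ReplayBuffer` hunks above, replaces chained `isinstance` calls with the tuple form. The two spellings are equivalent; a quick check (the `state` variable is just a stand-in):

```python
import numpy as np
import torch

state = np.zeros(3)   # stand-in for self.state[k]
assert (isinstance(state, torch.Tensor) or isinstance(state, np.ndarray)) \
    == isinstance(state, (torch.Tensor, np.ndarray))
```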
@@ -102,7 +102,7 @@ class BasePolicy(ABC, nn.Module):
                                gamma: float = 0.99,
                                gae_lambda: float = 0.95) -> Batch:
         """Compute returns over given full-length episodes, including the
-        implementation of Generalized Advantage Estimation (arXiv:1506.02438).
+        implementation of Generalized Advantage Estimator (arXiv:1506.02438).

         :param batch: a data batch which contains several full-episode data
             chronologically.
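The docstring above refers to Generalized Advantage Estimation (arXiv:1506.02438) driven by the `gamma` and `gae_lambda` arguments. For reference, an illustrative stand-alone version of that computation — not Tianshou's `compute_episodic_return`, just the textbook recursion it implements, with hypothetical array inputs:

```python
import numpy as np

def gae_returns(rew, v, v_next, done, gamma=0.99, gae_lambda=0.95):
    """delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
       A_t     = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}"""
    adv, gae = np.zeros_like(rew, dtype=float), 0.0
    for t in reversed(range(len(rew))):
        delta = rew[t] + gamma * v_next[t] * (1 - done[t]) - v[t]
        gae = delta + gamma * gae_lambda * (1 - done[t]) * gae
        adv[t] = gae
    return adv + v   # episodic returns = advantage + value baseline
```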
@@ -2,9 +2,7 @@ import torch


 class DiagGaussian(torch.distributions.Normal):
-    """Diagonal Gaussian Distribution
-
-    """
+    """Diagonal Gaussian distribution."""

     def log_prob(self, actions):
         return super().log_prob(actions).sum(-1, keepdim=True)
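A quick usage sketch of the class above, assuming it behaves exactly as the two-line definition suggests: it is `torch.distributions.Normal` with independent dimensions, except that `log_prob` sums over the last axis so each action gets a single joint log-probability.

```python
import torch
from tianshou.policy.dist import DiagGaussian   # new import path after this commit

mu, sigma = torch.zeros(4, 3), torch.ones(4, 3)  # batch of 4 actions, 3 dims each
dist = DiagGaussian(mu, sigma)
logp = dist.log_prob(torch.zeros(4, 3))
print(logp.shape)   # torch.Size([4, 1]): per-dimension log-probs summed with keepdim=True
```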
@@ -55,7 +55,6 @@ class A2CPolicy(PGPolicy):
         self._grad_norm = max_grad_norm
         self._batch = 64
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()

     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
@@ -99,7 +98,7 @@ class A2CPolicy(PGPolicy):
              **kwargs) -> Dict[str, List[float]]:
         self._batch = batch_size
         r = batch.returns
-        if self._rew_norm and r.std() > self.__eps:
+        if self._rew_norm and not np.isclose(r.std(), 0):
             batch.returns = (r - r.mean()) / r.std()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
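The guard change above (mirrored in the DDPG, PG, and PPO hunks below) swaps the hand-rolled `self.__eps` comparison for numpy's tolerance check before normalizing returns. A small hedged illustration of why the guard matters, with made-up data:

```python
import numpy as np

returns = np.array([1.0, 1.0, 1.0])        # zero-variance returns
if not np.isclose(returns.std(), 0):       # same guard as in learn() above
    returns = (returns - returns.mean()) / returns.std()
# with constant returns the normalization is skipped instead of dividing by ~0
```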
@@ -71,7 +71,6 @@ class DDPGPolicy(BasePolicy):
         # self.noise = OUNoise()
         self._rm_done = ignore_done
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()

     def set_eps(self, eps: float) -> None:
         """Set the eps for exploration."""
@@ -102,7 +101,7 @@ class DDPGPolicy(BasePolicy):
         if self._rew_norm:
             bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
             mean, std = bfr.mean(), bfr.std()
-            if std > self.__eps:
+            if not np.isclose(std, 0):
                 batch.rew = (batch.rew - mean) / std
         if self._rm_done:
             batch.done = batch.done * 0.
@@ -10,6 +10,7 @@ from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer

 class DQNPolicy(BasePolicy):
     """Implementation of Deep Q Network. arXiv:1312.5602
+    Implementation of Double Q-Learning. arXiv:1509.06461

     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
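The docstring above now cites Double Q-Learning (arXiv:1509.06461). For reference, an illustrative Double-DQN target — a hypothetical helper, not `DQNPolicy`'s actual code — where the online network chooses the next action and the target network evaluates it:

```python
import torch

def double_dqn_target(q_online_next, q_target_next, rew, done, gamma=0.99):
    # q_*_next: (batch, num_actions); rew, done: (batch,)  -- hypothetical shapes
    a_star = q_online_next.argmax(dim=1)                              # action from online net
    q_eval = q_target_next.gather(1, a_star.unsqueeze(1)).squeeze(1)  # evaluated by target net
    return rew + gamma * (1.0 - done) * q_eval
```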
@@ -36,7 +36,6 @@ class PGPolicy(BasePolicy):
         assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
         self._gamma = discount_factor
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()

     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
@@ -83,7 +82,7 @@ class PGPolicy(BasePolicy):
              **kwargs) -> Dict[str, List[float]]:
         losses = []
         r = batch.returns
-        if self._rew_norm and r.std() > self.__eps:
+        if self._rew_norm and not np.isclose(r.std(), 0):
             batch.returns = (r - r.mean()) / r.std()
         for _ in range(repeat):
             for b in batch.split(batch_size):
@@ -53,7 +53,7 @@ class PPOPolicy(PGPolicy):
                  ent_coef: float = .01,
                  action_range: Optional[Tuple[float, float]] = None,
                  gae_lambda: float = 0.95,
-                 dual_clip: float = None,
+                 dual_clip: Optional[float] = None,
                  value_clip: bool = True,
                  reward_normalization: bool = True,
                  **kwargs) -> None:
@@ -74,13 +74,12 @@ class PPOPolicy(PGPolicy):
         self._dual_clip = dual_clip
         self._value_clip = value_clip
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()

     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
         if self._rew_norm:
             mean, std = batch.rew.mean(), batch.rew.std()
-            if std > self.__eps:
+            if not np.isclose(std, 0):
                 batch.rew = (batch.rew - mean) / std
         if self._lambda in [0, 1]:
             return self.compute_episodic_return(
@@ -140,12 +139,12 @@ class PPOPolicy(PGPolicy):
         ).reshape(batch.v.shape)
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
-            if std > self.__eps:
+            if not np.isclose(std.item(), 0):
                 batch.returns = (batch.returns - mean) / std
         batch.adv = batch.returns - batch.v
         if self._rew_norm:
             mean, std = batch.adv.mean(), batch.adv.std()
-            if std > self.__eps:
+            if not np.isclose(std.item(), 0):
                 batch.adv = (batch.adv - mean) / std
         for _ in range(repeat):
             for b in batch.split(batch_size):
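The `dual_clip` parameter retyped to `Optional[float]` in the PPO hunks above controls a dual-clipped variant of the PPO surrogate. An illustrative helper showing where such a coefficient enters — a sketch of the technique, not `PPOPolicy.learn`, with hypothetical tensor arguments:

```python
import torch

def ppo_surrogate(ratio, adv, eps_clip=0.2, dual_clip=None):
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    loss = torch.min(surr1, surr2)
    if dual_clip is not None:                   # e.g. dual_clip = 5.0
        # lower-bound the objective for negative advantages only
        loss = torch.where(adv < 0, torch.max(loss, dual_clip * adv), loss)
    return -loss.mean()
```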
@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional

 from tianshou.data import Batch
 from tianshou.policy import DDPGPolicy
-from tianshou.policy.utils import DiagGaussian
+from tianshou.policy.dist import DiagGaussian


 class SACPolicy(DDPGPolicy):