item3 of #51

commit de556fd22d
parent 6237cc0d52
@@ -20,14 +20,14 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
 - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
-- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - Vanilla Imitation Learning
-- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
+- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
+- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
 
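Note: "replay-buffer based" above means every algorithm consumes transitions through a buffer-and-sample loop rather than a bespoke data path. The sketch below shows the shape of that loop in plain Python with a generic deque buffer and a placeholder update() function; it is not Tianshou's ReplayBuffer/Collector API, just an illustration of the pattern.

    import random
    from collections import deque

    buffer = deque(maxlen=10000)   # generic replay buffer, not tianshou.data.ReplayBuffer

    def update(minibatch):         # stand-in for any of the listed algorithms' update steps
        pass

    # collect: every interaction is pushed into the buffer as a transition
    for step in range(100):
        transition = {'obs': step, 'act': 0, 'rew': 1.0, 'done': False, 'obs_next': step + 1}
        buffer.append(transition)

    # learn: updates always read minibatches sampled from the buffer
    for _ in range(10):
        minibatch = random.sample(list(buffer), k=8)
        update(minibatch)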
@@ -11,14 +11,14 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
-* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
-* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
+* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
+* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
 
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
@@ -20,5 +20,14 @@ def test_batch():
     print(batch)
 
 
+def test_batch_over_batch():
+    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
+    batch2 = Batch(b=batch, c=[6, 7, 8])
+    batch2.b.b[-1] = 0
+    print(batch2)
+    assert batch2[-1].b.b == 0
+
+
 if __name__ == '__main__':
     test_batch()
+    test_batch_over_batch()
@@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.env import VectorEnv
 from tianshou.policy import PPOPolicy
-from tianshou.policy.utils import DiagGaussian
+from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 
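Note: only the module path changes here; the class itself is untouched. If a script has to run against checkouts from both sides of this commit, a guarded import along these lines works (purely a compatibility sketch, not part of the example being patched):

    try:
        from tianshou.policy.dist import DiagGaussian   # location after this commit
    except ImportError:
        from tianshou.policy.utils import DiagGaussian  # location before this commit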
@@ -70,7 +70,7 @@ class Batch(object):
         super().__init__()
         self._meta = {}
         for k, v in kwargs.items():
-            if (isinstance(v, list) or isinstance(v, np.ndarray)) \
+            if isinstance(v, (list, np.ndarray)) \
                     and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
                 self._meta[k] = list(v[0].keys())
                 for k_ in v[0].keys():
@@ -78,7 +78,7 @@ class Batch(object):
                     self.__dict__[k__] = np.array([
                         v[i][k_] for i in range(len(v))
                     ])
-            elif isinstance(v, dict) or isinstance(v, Batch):
+            elif isinstance(v, dict):
                 self._meta[k] = list(v.keys())
                 for k_ in v.keys():
                     k__ = '_' + k + '@' + k_
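Note: the two hunks above belong to Batch.__init__'s flattening scheme: sub-keys of a dict value (or of a list of dicts) are stored under mangled names '_' + key + '@' + subkey, and _meta records which sub-keys belong to which field. A standalone re-enactment of just the dict branch, using made-up data and plain dicts in place of the real class:

    import numpy as np

    kwargs = {'b': {'b': [4, 5, 6]}, 'c': [6, 7, 8]}   # made-up nested input
    meta, storage = {}, {}

    for k, v in kwargs.items():
        if isinstance(v, dict):                         # same branch as the hunk above
            meta[k] = list(v.keys())
            for k_ in v.keys():
                storage['_' + k + '@' + k_] = np.array(v[k_])
        else:
            storage[k] = np.array(v)

    print(meta)       # {'b': ['b']}
    print(storage)    # {'_b@b': array([4, 5, 6]), 'c': array([6, 7, 8])}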
@@ -151,7 +151,7 @@ class ReplayBuffer(object):
         if self.__dict__.get(name, None) is None:
             if isinstance(inst, np.ndarray):
                 self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
-            elif isinstance(inst, dict) or isinstance(inst, Batch):
+            elif isinstance(inst, (dict, Batch)):
                 if name == 'info':
                     self.__dict__[name] = np.array(
                         [{} for _ in range(self._maxsize)])
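Note: the surrounding code allocates buffer storage lazily, sized from the first value it sees, and 'info' is special-cased as an object array of dicts. A tiny illustration of those two allocation shapes with made-up values (not the ReplayBuffer class itself):

    import numpy as np

    maxsize = 4
    first_obs = np.array([0.1, 0.2, 0.3])                  # first sample decides the shape

    obs_storage = np.zeros([maxsize, *first_obs.shape])    # float storage of shape (4, 3)
    info_storage = np.array([{} for _ in range(maxsize)])  # object array holding one dict per slot

    print(obs_storage.shape, info_storage.dtype)           # (4, 3) object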
@@ -192,15 +192,13 @@ class Collector(object):
             return
         if isinstance(self.state, list):
             self.state[id] = None
-        elif isinstance(self.state, dict) or isinstance(self.state, Batch):
+        elif isinstance(self.state, (dict, Batch)):
             for k in self.state.keys():
                 if isinstance(self.state[k], list):
                     self.state[k][id] = None
-                elif isinstance(self.state[k], torch.Tensor) or \
-                        isinstance(self.state[k], np.ndarray):
+                elif isinstance(self.state[k], (torch.Tensor, np.ndarray)):
                     self.state[k][id] = 0
-        elif isinstance(self.state, torch.Tensor) or \
-                isinstance(self.state, np.ndarray):
+        elif isinstance(self.state, (torch.Tensor, np.ndarray)):
             self.state[id] = 0
 
     def _to_numpy(self, x: Union[
@@ -102,7 +102,7 @@ class BasePolicy(ABC, nn.Module):
             gamma: float = 0.99,
             gae_lambda: float = 0.95) -> Batch:
         """Compute returns over given full-length episodes, including the
-        implementation of Generalized Advantage Estimation (arXiv:1506.02438).
+        implementation of Generalized Advantage Estimator (arXiv:1506.02438).
 
         :param batch: a data batch which contains several full-episode data
             chronologically.
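Note: for readers who want the estimator named in this docstring spelled out, here is a standalone NumPy sketch of GAE (arXiv:1506.02438). It is not Tianshou's compute_episodic_return; the function name and arguments below are illustrative only.

    import numpy as np

    def gae_sketch(rew, v_s, v_s_, done, gamma=0.99, gae_lambda=0.95):
        """rew/done are per-step arrays, v_s/v_s_ are value estimates for s_t and s_{t+1};
        returns the lambda-returns (advantage plus baseline)."""
        delta = rew + gamma * v_s_ * (1.0 - done) - v_s       # TD residuals
        adv = np.zeros_like(rew)
        gae = 0.0
        for t in reversed(range(len(rew))):                   # backward recursion over the episode
            gae = delta[t] + gamma * gae_lambda * (1.0 - done[t]) * gae
            adv[t] = gae
        return adv + v_s                                      # value targets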
@@ -2,9 +2,7 @@ import torch
 
 
 class DiagGaussian(torch.distributions.Normal):
-    """Diagonal Gaussian Distribution
-
-    """
+    """Diagonal Gaussian distribution."""
 
     def log_prob(self, actions):
         return super().log_prob(actions).sum(-1, keepdim=True)
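Note: the class body is exactly the three lines shown above, so its behaviour is easy to check in isolation. Summing log_prob over the last dimension treats the action dimensions as independent Gaussians and yields one joint log-density per action; the tensors below are made-up example values.

    import torch


    class DiagGaussian(torch.distributions.Normal):
        """Diagonal Gaussian distribution (copied from the hunk above)."""

        def log_prob(self, actions):
            return super().log_prob(actions).sum(-1, keepdim=True)


    mu, sigma = torch.zeros(4, 3), torch.ones(4, 3)   # batch of 4 actions, 3 dims each
    dist = DiagGaussian(mu, sigma)
    act = dist.sample()                               # shape (4, 3)
    print(dist.log_prob(act).shape)                   # torch.Size([4, 1])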
@@ -55,7 +55,6 @@ class A2CPolicy(PGPolicy):
         self._grad_norm = max_grad_norm
         self._batch = 64
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
@@ -99,7 +98,7 @@ class A2CPolicy(PGPolicy):
               **kwargs) -> Dict[str, List[float]]:
         self._batch = batch_size
         r = batch.returns
-        if self._rew_norm and r.std() > self.__eps:
+        if self._rew_norm and not np.isclose(r.std(), 0):
             batch.returns = (r - r.mean()) / r.std()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
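Note: this is the guard pattern the commit standardizes on across A2C, PG, DDPG and PPO: instead of keeping a cached float32 epsilon around, the code asks np.isclose whether the standard deviation is effectively zero before normalizing. Both guards skip the degenerate case, shown here on a made-up constant batch:

    import numpy as np

    returns = np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32)  # degenerate batch, std == 0

    # old guard: compare std against float32 machine epsilon
    eps = np.finfo(np.float32).eps.item()
    if returns.std() > eps:
        returns = (returns - returns.mean()) / returns.std()

    # new guard: treat an effectively-zero std as zero and skip normalization
    if not np.isclose(returns.std(), 0):
        returns = (returns - returns.mean()) / returns.std()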
@@ -71,7 +71,6 @@ class DDPGPolicy(BasePolicy):
         # self.noise = OUNoise()
         self._rm_done = ignore_done
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def set_eps(self, eps: float) -> None:
         """Set the eps for exploration."""
@@ -102,7 +101,7 @@ class DDPGPolicy(BasePolicy):
         if self._rew_norm:
             bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
             mean, std = bfr.mean(), bfr.std()
-            if std > self.__eps:
+            if not np.isclose(std, 0):
                 batch.rew = (batch.rew - mean) / std
         if self._rm_done:
             batch.done = batch.done * 0.
@@ -10,6 +10,7 @@ from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
 
 class DQNPolicy(BasePolicy):
     """Implementation of Deep Q Network. arXiv:1312.5602
+    Implementation of Double Q-Learning. arXiv:1509.06461
 
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
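Note: the docstring now cites Double Q-Learning (arXiv:1509.06461) alongside the original DQN paper. As a reminder of the difference the added reference points at, here is a generic sketch of the two targets on made-up tensors; this is not DQNPolicy's code.

    import torch

    gamma = 0.99
    rew = torch.tensor([1.0, 0.0])            # rewards for two transitions
    done = torch.tensor([0.0, 1.0])
    q_online = torch.tensor([[1.0, 2.0],      # Q(s', .) from the online network
                             [0.5, 0.1]])
    q_target = torch.tensor([[0.9, 1.5],      # Q(s', .) from the target network
                             [0.4, 0.2]])

    # vanilla DQN: the target network both selects and evaluates the next action
    y_dqn = rew + gamma * (1 - done) * q_target.max(dim=1).values

    # double DQN: the online network selects the action, the target network evaluates it
    a_star = q_online.argmax(dim=1)
    y_ddqn = rew + gamma * (1 - done) * q_target.gather(1, a_star.unsqueeze(1)).squeeze(1)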
@@ -36,7 +36,6 @@ class PGPolicy(BasePolicy):
         assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
         self._gamma = discount_factor
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
@@ -83,7 +82,7 @@ class PGPolicy(BasePolicy):
               **kwargs) -> Dict[str, List[float]]:
         losses = []
         r = batch.returns
-        if self._rew_norm and r.std() > self.__eps:
+        if self._rew_norm and not np.isclose(r.std(), 0):
             batch.returns = (r - r.mean()) / r.std()
         for _ in range(repeat):
             for b in batch.split(batch_size):
@@ -53,7 +53,7 @@ class PPOPolicy(PGPolicy):
                  ent_coef: float = .01,
                  action_range: Optional[Tuple[float, float]] = None,
                  gae_lambda: float = 0.95,
-                 dual_clip: float = None,
+                 dual_clip: Optional[float] = None,
                  value_clip: bool = True,
                  reward_normalization: bool = True,
                  **kwargs) -> None:
@@ -74,13 +74,12 @@ class PPOPolicy(PGPolicy):
         self._dual_clip = dual_clip
         self._value_clip = value_clip
         self._rew_norm = reward_normalization
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
         if self._rew_norm:
             mean, std = batch.rew.mean(), batch.rew.std()
-            if std > self.__eps:
+            if not np.isclose(std, 0):
                 batch.rew = (batch.rew - mean) / std
         if self._lambda in [0, 1]:
             return self.compute_episodic_return(
@@ -140,12 +139,12 @@ class PPOPolicy(PGPolicy):
         ).reshape(batch.v.shape)
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
-            if std > self.__eps:
+            if not np.isclose(std.item(), 0):
                 batch.returns = (batch.returns - mean) / std
         batch.adv = batch.returns - batch.v
         if self._rew_norm:
             mean, std = batch.adv.mean(), batch.adv.std()
-            if std > self.__eps:
+            if not np.isclose(std.item(), 0):
                 batch.adv = (batch.adv - mean) / std
         for _ in range(repeat):
             for b in batch.split(batch_size):
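Note: unlike the NumPy case in PGPolicy above, the std here is reduced to a Python float with .item() before the np.isclose check, which suggests batch.returns and batch.adv are torch tensors at this point. A one-line illustration with a made-up tensor:

    import numpy as np
    import torch

    std = torch.tensor([1.0, 1.0, 1.0]).std()   # 0-dim torch tensor
    print(not np.isclose(std.item(), 0))        # False: zero spread, so normalization is skipped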
@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional
 
 from tianshou.data import Batch
 from tianshou.policy import DDPGPolicy
-from tianshou.policy.utils import DiagGaussian
+from tianshou.policy.dist import DiagGaussian
 
 
 class SACPolicy(DDPGPolicy):