nstep all (fix #51)

Trinkle23897 2020-06-03 13:59:47 +08:00
parent ff81a18f42
commit dc451dfe88
17 changed files with 127 additions and 84 deletions

View File

@@ -20,7 +20,7 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
-- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
+- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -30,7 +30,7 @@
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
-Tianshou supports parallel workers for all algorithms as well since all of them are reformatted as replay-buffer based algorithms. All of the algorithms support recurrent state representation in actor network (RNN-style training in POMDP). The environment state can be any type (Dict, self-defined class, ...).
+**Tianshou supports parallel workers for all algorithms as well since all of them are reformatted as replay-buffer based algorithms. All of the algorithms support recurrent state representation in actor network (RNN-style training in POMDP). The environment state can be any type (dict, self-defined class, ...). All Q-learning algorithms support n-step returns estimation.**
 In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
@@ -102,7 +102,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
 | Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
 | PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? |
 | DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
-| A2C - CartPole | 6.36±1.63s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
+| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
 | PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
 | PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? |
 | DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s |
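Note: the n-step claim added to the README maps onto the new `estimation_step` keyword that this commit threads through the Q-learning policies. A minimal sketch of how a caller would request 3-step targets, mirroring the updated test scripts below; `actor`, `actor_optim`, `critic`, `critic_optim`, and `env` are placeholders the caller has to build, and only the argument order and keywords come from this commit.

# Hedged sketch: networks, optimizers, and env are user-supplied placeholders.
from tianshou.policy import DDPGPolicy

policy = DDPGPolicy(
    actor, actor_optim, critic, critic_optim,
    0.005, 0.99, 0.1,  # tau, gamma, exploration noise
    [env.action_space.low[0], env.action_space.high[0]],
    reward_normalization=True,
    ignore_done=True,
    estimation_step=3)  # bootstrap the target from Q(s_{t+3}) instead of Q(s_{t+1})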

View File

@@ -36,7 +36,9 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--ignore-done', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=1)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -78,7 +80,9 @@ def test_ddpg(args=get_args()):
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=args.rew_norm, ignore_done=True)
+        reward_normalization=args.rew_norm,
+        ignore_done=args.ignore_done,
+        estimation_step=args.n_step)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
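Note: a plausible reason for switching `--rew-norm` (and defining the new `--ignore-done`/`--n-step` flags) as `type=int` rather than `type=bool`, not stated in the commit itself, is the standard argparse pitfall that `bool()` is applied to the raw string, so any non-empty value, including the string "False", parses as True. A small self-contained demonstration:

# Why type=int instead of type=bool for toggle flags (assumed rationale).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rew-norm-bool', type=bool, default=True)
parser.add_argument('--rew-norm-int', type=int, default=1)
args = parser.parse_args(['--rew-norm-bool', 'False', '--rew-norm-int', '0'])
print(bool(args.rew_norm_bool))  # True  -- bool('False') is truthy
print(bool(args.rew_norm_int))   # False -- int('0') == 0 behaves as expected

With integer flags, `--rew-norm 0` disables the option as expected while still acting as a truthy/falsy value inside the script.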

View File

@@ -44,9 +44,9 @@ def get_args():
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.95)
-    parser.add_argument('--rew-norm', type=bool, default=True)
-    # parser.add_argument('--dual-clip', type=float, default=5.)
-    parser.add_argument('--value-clip', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=int, default=1)
     args = parser.parse_known_args()[0]
     return args
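Note: the `--dual-clip` flag uncommented here (default `None`, i.e. disabled; the previously commented-out default was 5.0) refers to the dual-clip PPO objective, which adds a second, looser clip for transitions with negative advantage so that a very large probability ratio cannot dominate the update. A hedged sketch of that surrogate, written from the general dual-clip formulation rather than Tianshou's exact implementation:

# Hedged sketch of the dual-clip PPO surrogate; names and defaults are illustrative.
import torch

def dual_clip_surrogate(ratio, adv, eps_clip=0.2, dual_clip=5.0):
    """ratio = pi_new(a|s) / pi_old(a|s); adv = advantage estimate."""
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    clipped = torch.min(surr1, surr2)
    if dual_clip is None:
        return clipped  # standard clipped surrogate
    # For negative advantages, bound the objective from below by dual_clip * adv.
    return torch.where(adv < 0, torch.max(clipped, dual_clip * adv), clipped)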

View File

@@ -37,7 +37,9 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--ignore-done', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -83,7 +85,9 @@ def test_sac_with_il(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=args.rew_norm, ignore_done=True)
+        reward_normalization=args.rew_norm,
+        ignore_done=args.ignore_done,
+        estimation_step=args.n_step)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -126,7 +130,7 @@ def test_sac_with_il(args=get_args()):
     train_collector.reset()
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch // 5, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()

View File

@@ -39,7 +39,9 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--ignore-done', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=1)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -86,7 +88,9 @@ def test_td3(args=get_args()):
         args.tau, args.gamma, args.exploration_noise, args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=args.rew_norm, ignore_done=True)
+        reward_normalization=args.rew_norm,
+        ignore_done=args.ignore_done,
+        estimation_step=args.n_step)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@@ -31,7 +31,7 @@ def get_args():
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)
-    parser.add_argument('--training-num', type=int, default=32)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -40,7 +40,7 @@ def get_args():
         default='cuda' if torch.cuda.is_available() else 'cpu')
     # a2c special
     parser.add_argument('--vf-coef', type=float, default=0.5)
-    parser.add_argument('--ent-coef', type=float, default=0.001)
+    parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--max-grad-norm', type=float, default=None)
     parser.add_argument('--gae-lambda', type=float, default=1.)
     parser.add_argument('--rew-norm', type=bool, default=False)

View File

@@ -102,7 +102,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')

View File

@@ -43,9 +43,9 @@ def get_args():
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.8)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
     parser.add_argument('--dual-clip', type=float, default=None)
-    parser.add_argument('--value-clip', type=bool, default=True)
+    parser.add_argument('--value-clip', type=int, default=1)
     args = parser.parse_known_args()[0]
     return args

View File

@@ -1,5 +1,6 @@
 from tianshou.data.batch import Batch
-from tianshou.data.utils import to_numpy, to_torch
+from tianshou.data.utils import to_numpy, to_torch, \
+    to_torch_as
 from tianshou.data.buffer import ReplayBuffer, \
     ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector
@@ -8,6 +9,7 @@ __all__ = [
     'Batch',
     'to_numpy',
     'to_torch',
+    'to_torch_as',
     'ReplayBuffer',
     'ListReplayBuffer',
     'PrioritizedReplayBuffer',

View File

@@ -38,3 +38,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
     elif isinstance(x, Batch):
         x.to_torch(dtype, device)
     return x
+
+
+def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray],
+                y: torch.Tensor
+                ) -> Union[dict, Batch, torch.Tensor]:
+    """Return an object without np.ndarray. Same as
+    ``to_torch(x, dtype=y.dtype, device=y.device)``.
+    """
+    assert isinstance(y, torch.Tensor)
+    return to_torch(x, dtype=y.dtype, device=y.device)
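Note: the new helper simply forwards to `to_torch` with the reference tensor's dtype and device. A short usage sketch (array values are arbitrary):

# Convert a numpy batch member to a tensor matching an existing tensor in one call.
import numpy as np
import torch
from tianshou.data import to_torch_as

v = torch.zeros(4)  # e.g. a critic output living on some device
act = to_torch_as(np.array([0., 1., 2., 3.], dtype=np.float32), v)
assert isinstance(act, torch.Tensor)
assert act.dtype == v.dtype and act.device == v.device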

View File

@@ -4,7 +4,7 @@ from torch import nn
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union, Optional, Callable
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 class BasePolicy(ABC, nn.Module):
@@ -138,9 +138,10 @@ class BasePolicy(ABC, nn.Module):
             batch: Batch,
             buffer: ReplayBuffer,
             indice: np.ndarray,
-            target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray],
+            target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
             gamma: float = 0.99,
-            n_step: int = 1
+            n_step: int = 1,
+            rew_norm: bool = False
             ) -> np.ndarray:
         r"""Compute n-step return for Q-learning targets:
@@ -159,13 +160,25 @@
         :type buffer: :class:`~tianshou.data.ReplayBuffer`
         :param indice: sampled timestep.
         :type indice: numpy.ndarray
+        :param function target_q_fn: a function receives :math:`t+n-1` step's
+            data and compute target Q value.
         :param float gamma: the discount factor, should be in [0, 1], defaults
             to 0.99.
         :param int n_step: the number of estimation step, should be an int
            greater than 0, defaults to 1.
+        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
+            to ``False``.
-        :return: a Batch. The result will be stored in batch.returns.
+        :return: a Batch. The result will be stored in batch.returns as a
+            torch.Tensor with shape (bsz, ).
         """
+        if rew_norm:
+            bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
+            mean, std = bfr.mean(), bfr.std()
+            if np.isclose(std, 0):
+                mean, std = 0, 1
+        else:
+            mean, std = 0, 1
         returns = np.zeros_like(indice)
         gammas = np.zeros_like(indice) + n_step
         done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
@@ -173,10 +186,11 @@
             now = (indice + n) % buf_len
             gammas[done[now] > 0] = n
             returns[done[now] > 0] = 0
-            returns = rew[now] + gamma * returns
+            returns = (rew[now] - mean) / std + gamma * returns
         terminal = (indice + n_step - 1) % buf_len
-        target_q = target_q_fn(buffer, terminal)
+        target_q = target_q_fn(buffer, terminal).squeeze()
         target_q[gammas != n_step] = 0
-        returns += (gamma ** gammas) * target_q
-        batch.returns = returns
+        returns = to_torch_as(returns, target_q)
+        gammas = to_torch_as(gamma ** gammas, target_q)
+        batch.returns = target_q * gammas + returns
         return batch
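Note: for readers who want the recursion above outside of the `ReplayBuffer`/`Batch` machinery, here is a simplified, self-contained rendering of the backward loop (reward normalization omitted; these names are illustrative, not Tianshou API):

# Simplified sketch of the n-step return recursion in compute_nstep_return.
import numpy as np

def nstep_return(rew, done, target_q, indice, gamma=0.99, n_step=3):
    """rew, done: full-buffer arrays; target_q: bootstrap value Q(s_{t+n}) for
    each sampled index; indice: sampled transition indices."""
    buf_len = len(rew)
    returns = np.zeros(len(indice))
    gammas = np.zeros_like(indice) + n_step
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n      # stop discounting where an episode ended
        returns[done[now] > 0] = 0.0   # and drop rewards collected past that point
        returns = rew[now] + gamma * returns
    target_q = np.where(gammas == n_step, target_q, 0.0)  # no bootstrap past a done
    return target_q * gamma ** gammas + returns

For unit rewards, gamma = 0.9, n_step = 3 and a bootstrap value of 10, the target at index 0 of an unfinished episode is 1 + 0.9 + 0.81 + 0.729 * 10 = 10.0, which matches what the loop above produces.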

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from typing import Dict, List, Union, Optional
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 class A2CPolicy(PGPolicy):
@@ -106,14 +106,14 @@ class A2CPolicy(PGPolicy):
                 self.optim.zero_grad()
                 dist = self(b).dist
                 v = self.critic(b.obs)
-                a = to_torch(b.act, device=v.device)
-                r = to_torch(b.returns, device=v.device)
+                a = to_torch_as(b.act, v)
+                r = to_torch_as(b.returns, v)
                 a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r[:, None], v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
-                if self._grad_norm:
+                if self._grad_norm is not None:
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) +
                         list(self.critic.parameters()),

View File

@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional
 from tianshou.policy import BasePolicy
 # from tianshou.exploration import OUNoise
-from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 class DDPGPolicy(BasePolicy):
@@ -29,6 +29,8 @@ class DDPGPolicy(BasePolicy):
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+    :param int estimation_step: greater than 1, the number of steps to look
+        ahead.
     .. seealso::
@@ -47,6 +49,7 @@ class DDPGPolicy(BasePolicy):
                  action_range: Optional[Tuple[float, float]] = None,
                  reward_normalization: bool = False,
                  ignore_done: bool = False,
+                 estimation_step: int = 1,
                  **kwargs) -> None:
         super().__init__(**kwargs)
         if actor is not None:
@@ -71,6 +74,8 @@ class DDPGPolicy(BasePolicy):
         # self.noise = OUNoise()
         self._rm_done = ignore_done
         self._rew_norm = reward_normalization
+        assert estimation_step > 0, 'estimation_step should greater than 0'
+        self._n_step = estimation_step
     def set_eps(self, eps: float) -> None:
         """Set the eps for exploration."""
@@ -96,15 +101,21 @@ class DDPGPolicy(BasePolicy):
                 self.critic_old.parameters(), self.critic.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
+    def _target_q(self, buffer: ReplayBuffer,
+                  indice: np.ndarray) -> torch.Tensor:
+        batch = buffer[indice]  # batch.obs_next: s_{t+n}
+        with torch.no_grad():
+            target_q = self.critic_old(batch.obs_next, self(
+                batch, model='actor_old', input='obs_next', eps=0).act)
+        return target_q
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
-        if self._rew_norm:
-            bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
-            mean, std = bfr.mean(), bfr.std()
-            if not np.isclose(std, 0):
-                batch.rew = (batch.rew - mean) / std
         if self._rm_done:
             batch.done = batch.done * 0.
+        batch = self.compute_nstep_return(
+            batch, buffer, indice, self._target_q,
+            self._gamma, self._n_step, self._rew_norm)
         return batch
     def forward(self, batch: Batch,
@@ -143,16 +154,9 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=logits, state=h)
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        with torch.no_grad():
-            target_q = self.critic_old(batch.obs_next, self(
-                batch, model='actor_old', input='obs_next', eps=0).act)
-            dev = target_q.device
-            rew = to_torch(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-            done = to_torch(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
-            target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
+        target_q = to_torch_as(batch.returns, current_q)
+        target_q = target_q[:, None]
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
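Note: with `process_fn` now producing `batch.returns` via `compute_nstep_return`, `learn()` reduces to a plain regression of the critic onto those returns. A minimal, self-contained illustration of that critic update, with dummy tensors standing in for the network and batch:

# Illustrative critic update matching the new learn(); data and network are dummies.
import numpy as np
import torch
import torch.nn.functional as F
from tianshou.data import to_torch_as

critic = torch.nn.Linear(4, 1)                     # stand-in for the critic network
obs_act = torch.randn(8, 4)                        # stand-in for (s, a) features
returns = np.random.randn(8).astype(np.float32)    # stand-in for batch.returns

current_q = critic(obs_act)                          # shape (8, 1)
target_q = to_torch_as(returns, current_q)[:, None]  # match dtype/device, add axis
critic_loss = F.mse_loss(current_q, target_q)
critic_loss.backward()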

View File

@@ -6,7 +6,7 @@ from typing import Dict, Union, Optional
 from tianshou.policy import BasePolicy
 from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
-    to_torch, to_numpy
+    to_torch_as, to_numpy
 class DQNPolicy(BasePolicy):
@@ -69,18 +69,18 @@ class DQNPolicy(BasePolicy):
         self.model_old.load_state_dict(self.model.state_dict())
     def _target_q(self, buffer: ReplayBuffer,
-                  indice: np.ndarray) -> np.ndarray:
-        data = buffer[indice]
+                  indice: np.ndarray) -> torch.Tensor:
+        batch = buffer[indice]  # batch.obs_next: s_{t+n}
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(data, input='obs_next', eps=0).act
-            target_q = self(
-                data, model='model_old', input='obs_next').logits
-            target_q = to_numpy(target_q)
+            a = self(batch, input='obs_next', eps=0).act
+            with torch.no_grad():
+                target_q = self(
+                    batch, model='model_old', input='obs_next').logits
             target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(data, input='obs_next').logits
-            target_q = to_numpy(target_q).max(axis=1)
+            with torch.no_grad():
+                target_q = self(batch, input='obs_next').logits.max(dim=1)[0]
         return target_q
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
@@ -144,12 +144,11 @@
         self.optim.zero_grad()
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
-        r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
+        r = to_torch_as(batch.returns, q)
         if hasattr(batch, 'update_weight'):
            td = r - q
            batch.update_weight(batch.indice, to_numpy(td))
-            impt_weight = to_torch(batch.impt_weight,
-                                   device=q.device, dtype=torch.float)
+            impt_weight = to_torch_as(batch.impt_weight, q)
            loss = (td.pow(2) * impt_weight).mean()
         else:
            loss = F.mse_loss(q, r)
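Note: the comment retained above, `target_Q = Q_old(s_, argmax(Q_new(s_, *)))`, is the double-DQN estimator: the online network chooses the next action, the target network evaluates it. A self-contained numeric sketch of just that selection step (values arbitrary):

# Double-DQN target selection: online net picks the action, target net scores it.
import torch

q_online = torch.tensor([[1.0, 3.0, 2.0],
                         [0.5, 0.1, 0.9]])   # Q_new(s_{t+n}, .) per sample
q_target = torch.tensor([[0.8, 2.5, 2.9],
                         [0.4, 0.2, 1.1]])   # Q_old(s_{t+n}, .) per sample

a = q_online.argmax(dim=1)                        # actions chosen by the online net
target_q = q_target[torch.arange(len(a)), a]      # tensor([2.5, 1.1])
# Without double estimation the target would simply be q_target.max(dim=1)[0].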

View File

@@ -4,7 +4,7 @@ from torch import nn
 from typing import Dict, List, Tuple, Union, Optional
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
 class PPOPolicy(PGPolicy):
@@ -129,14 +129,12 @@ class PPOPolicy(PGPolicy):
             for b in batch.split(batch_size, shuffle=False):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
-                    to_torch(b.act, device=v[0].device)))
+                    to_torch_as(b.act, v[0])))
         batch.v = torch.cat(v, dim=0)  # old value
-        dev = batch.v.device
-        batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
+        batch.act = to_torch_as(batch.act, v[0])
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch(
-            batch.returns, dtype=torch.float, device=dev
-        ).reshape(batch.v.shape)
+        batch.returns = to_torch_as(
+            batch.returns, v[0]).reshape(batch.v.shape)
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
             if not np.isclose(std.item(), 0):

View File

@@ -4,9 +4,9 @@ from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Union, Optional
-from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.dist import DiagGaussian
+from tianshou.data import Batch, to_torch_as, ReplayBuffer
 class SACPolicy(DDPGPolicy):
@@ -55,10 +55,11 @@ class SACPolicy(DDPGPolicy):
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
+                 estimation_step: int = 1,
                 **kwargs) -> None:
         super().__init__(None, None, None, None, tau, gamma, 0,
                          action_range, reward_normalization, ignore_done,
-                         **kwargs)
+                         estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -105,23 +106,23 @@ class SACPolicy(DDPGPolicy):
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
-    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
+    def _target_q(self, buffer: ReplayBuffer,
+                  indice: np.ndarray) -> torch.Tensor:
+        batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
             obs_next_result = self(batch, input='obs_next')
             a_ = obs_next_result.act
-            dev = a_.device
-            batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
+            batch.act = to_torch_as(batch.act, a_)
             target_q = torch.min(
                 self.critic1_old(batch.obs_next, a_),
                 self.critic2_old(batch.obs_next, a_),
             ) - self._alpha * obs_next_result.log_prob
-            rew = to_torch(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-            done = to_torch(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
-            target_q = (rew + (1. - done) * self._gamma * target_q)
+        return target_q
+    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
+        target_q = to_torch_as(batch.returns, current_q1)[:, None]
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
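Note: the value moved into `_target_q` here is the usual SAC soft target, the minimum of the two target critics minus the entropy term `alpha * log pi(a'|s')`. A self-contained numeric sketch (values arbitrary):

# Soft target value used by SAC's _target_q; tensors below are dummies.
import torch

alpha = 0.2
q1 = torch.tensor([[1.0], [2.0]])          # critic1_old(s', a')
q2 = torch.tensor([[1.5], [1.8]])          # critic2_old(s', a')
log_prob = torch.tensor([[-0.7], [-0.3]])  # log pi(a'|s') for the sampled a'

target_q = torch.min(q1, q2) - alpha * log_prob  # tensor([[1.14], [1.86]])

compute_nstep_return then discounts this bootstrap value and adds the (possibly n-step) rewards before it reaches learn().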

View File

@@ -1,10 +1,11 @@
 import torch
+import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Optional
-from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
+from tianshou.data import Batch, ReplayBuffer
 class TD3Policy(DDPGPolicy):
@@ -62,10 +63,11 @@ class TD3Policy(DDPGPolicy):
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
+                 estimation_step: int = 1,
                 **kwargs) -> None:
         super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, action_range, reward_normalization,
-                         ignore_done, **kwargs)
+                         ignore_done, estimation_step, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
         self.critic1_optim = critic1_optim
@@ -100,25 +102,26 @@ class TD3Policy(DDPGPolicy):
                 self.critic2_old.parameters(), self.critic2.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
-    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
+    def _target_q(self, buffer: ReplayBuffer,
+                  indice: np.ndarray) -> torch.Tensor:
+        batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
             a_ = self(batch, model='actor_old', input='obs_next').act
             dev = a_.device
             noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
-            if self._noise_clip >= 0:
+            if self._noise_clip > 0:
                 noise = noise.clamp(-self._noise_clip, self._noise_clip)
             a_ += noise
             a_ = a_.clamp(self._range[0], self._range[1])
             target_q = torch.min(
                 self.critic1_old(batch.obs_next, a_),
                 self.critic2_old(batch.obs_next, a_))
-            rew = to_torch(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-            done = to_torch(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
-            target_q = (rew + (1. - done) * self._gamma * target_q)
+        return target_q
+    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
+        target_q = batch.returns[:, None]
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
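Note: the smoothing logic kept inside TD3's `_target_q` perturbs the target action with clipped Gaussian noise before the clamped action is scored by the two target critics; the guard changed from `>= 0` to `> 0`, so a zero `noise_clip` now skips the clipping step entirely. A self-contained sketch of just that noise step (values arbitrary):

# Target-policy smoothing step used inside TD3's _target_q.
import torch

policy_noise, noise_clip = 0.2, 0.5
act_low, act_high = -1.0, 1.0

a_ = torch.tensor([0.3, -0.9, 0.95])        # stand-in for actor_old(s_{t+n})
noise = torch.randn_like(a_) * policy_noise
if noise_clip > 0:
    noise = noise.clamp(-noise_clip, noise_clip)
a_ = (a_ + noise).clamp(act_low, act_high)  # smoothed action fed to the target critics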