From dc451dfe888a99abec99fe8373fa15a874007374 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 3 Jun 2020 13:59:47 +0800 Subject: [PATCH] nstep all (fix #51) --- README.md | 6 ++--- test/continuous/test_ddpg.py | 8 +++++-- test/continuous/test_ppo.py | 6 ++--- test/continuous/test_sac_with_il.py | 10 ++++++--- test/continuous/test_td3.py | 8 +++++-- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo.py | 4 ++-- tianshou/data/__init__.py | 4 +++- tianshou/data/utils.py | 10 +++++++++ tianshou/policy/base.py | 30 ++++++++++++++++++------- tianshou/policy/modelfree/a2c.py | 8 +++---- tianshou/policy/modelfree/ddpg.py | 34 ++++++++++++++++------------- tianshou/policy/modelfree/dqn.py | 23 ++++++++++--------- tianshou/policy/modelfree/ppo.py | 12 +++++----- tianshou/policy/modelfree/sac.py | 21 +++++++++--------- tianshou/policy/modelfree/td3.py | 21 ++++++++++-------- 17 files changed, 127 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index 65d576a..586cbb6 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf) - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) -- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns +- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) @@ -30,7 +30,7 @@ - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) -Tianshou supports parallel workers for all algorithms as well since all of them are reformatted as replay-buffer based algorithms. All of the algorithms support recurrent state representation in actor network (RNN-style training in POMDP). The environment state can be any type (Dict, self-defined class, ...). +**Tianshou supports parallel workers for all algorithms as well since all of them are reformatted as replay-buffer based algorithms. All of the algorithms support recurrent state representation in actor network (RNN-style training in POMDP). The environment state can be any type (dict, self-defined class, ...). All Q-learning algorithms support n-step returns estimation.** In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. @@ -102,7 +102,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m | Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch | | PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? | | DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? | -| A2C - CartPole | 6.36±1.63s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? | +| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? 
| | PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? | | PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? | | DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s | diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 60241c7..ddfcd04 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -36,7 +36,9 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--rew-norm', type=bool, default=True) + parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--ignore-done', type=int, default=1) + parser.add_argument('--n-step', type=int, default=1) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -78,7 +80,9 @@ def test_ddpg(args=get_args()): actor, actor_optim, critic, critic_optim, args.tau, args.gamma, args.exploration_noise, [env.action_space.low[0], env.action_space.high[0]], - reward_normalization=args.rew_norm, ignore_done=True) + reward_normalization=args.rew_norm, + ignore_done=args.ignore_done, + estimation_step=args.n_step) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 8280f7f..dd0e765 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -44,9 +44,9 @@ def get_args(): parser.add_argument('--eps-clip', type=float, default=0.2) parser.add_argument('--max-grad-norm', type=float, default=0.5) parser.add_argument('--gae-lambda', type=float, default=0.95) - parser.add_argument('--rew-norm', type=bool, default=True) - # parser.add_argument('--dual-clip', type=float, default=5.) - parser.add_argument('--value-clip', type=bool, default=True) + parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--dual-clip', type=float, default=None) + parser.add_argument('--value-clip', type=int, default=1) args = parser.parse_known_args()[0] return args diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 65d9e47..8b1a3d1 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -37,7 +37,9 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--rew-norm', type=bool, default=True) + parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--ignore-done', type=int, default=1) + parser.add_argument('--n-step', type=int, default=4) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -83,7 +85,9 @@ def test_sac_with_il(args=get_args()): actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, args.tau, args.gamma, args.alpha, [env.action_space.low[0], env.action_space.high[0]], - reward_normalization=args.rew_norm, ignore_done=True) + reward_normalization=args.rew_norm, + ignore_done=args.ignore_done, + estimation_step=args.n_step) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) @@ -126,7 +130,7 @@ def test_sac_with_il(args=get_args()): train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch // 5, args.collect_per_step, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 5312134..5ee3d80 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -39,7 +39,9 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--rew-norm', type=bool, default=True) + parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--ignore-done', type=int, default=1) + parser.add_argument('--n-step', type=int, default=1) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -86,7 +88,9 @@ def test_td3(args=get_args()): args.tau, args.gamma, args.exploration_noise, args.policy_noise, args.update_actor_freq, args.noise_clip, [env.action_space.low[0], env.action_space.high[0]], - reward_normalization=args.rew_norm, ignore_done=True) + reward_normalization=args.rew_norm, + ignore_done=args.ignore_done, + estimation_step=args.n_step) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 1b12803..267732f 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=2) - parser.add_argument('--training-num', type=int, default=32) + parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) @@ -40,7 +40,7 @@ def get_args(): default='cuda' if torch.cuda.is_available() else 'cpu') # a2c special parser.add_argument('--vf-coef', type=float, default=0.5) - parser.add_argument('--ent-coef', type=float, default=0.001) + parser.add_argument('--ent-coef', type=float, default=0.0) parser.add_argument('--max-grad-norm', type=float, default=None) parser.add_argument('--gae-lambda', type=float, default=1.) 
parser.add_argument('--rew-norm', type=bool, default=False) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 1890fba..7607151 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -102,7 +102,7 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--rew-norm', type=bool, default=True) + parser.add_argument('--rew-norm', type=int, default=1) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 2d2a484..44850a8 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -43,9 +43,9 @@ def get_args(): parser.add_argument('--eps-clip', type=float, default=0.2) parser.add_argument('--max-grad-norm', type=float, default=0.5) parser.add_argument('--gae-lambda', type=float, default=0.8) - parser.add_argument('--rew-norm', type=bool, default=True) + parser.add_argument('--rew-norm', type=int, default=1) parser.add_argument('--dual-clip', type=float, default=None) - parser.add_argument('--value-clip', type=bool, default=True) + parser.add_argument('--value-clip', type=int, default=1) args = parser.parse_known_args()[0] return args diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index fd57d87..5d097a0 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,5 +1,6 @@ from tianshou.data.batch import Batch -from tianshou.data.utils import to_numpy, to_torch +from tianshou.data.utils import to_numpy, to_torch, \ + to_torch_as from tianshou.data.buffer import ReplayBuffer, \ ListReplayBuffer, PrioritizedReplayBuffer from tianshou.data.collector import Collector @@ -8,6 +9,7 @@ __all__ = [ 'Batch', 'to_numpy', 'to_torch', + 'to_torch_as', 'ReplayBuffer', 'ListReplayBuffer', 'PrioritizedReplayBuffer', diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 62dfb46..8890c4b 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -38,3 +38,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], elif isinstance(x, Batch): x.to_torch(dtype, device) return x + + +def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], + y: torch.Tensor + ) -> Union[dict, Batch, torch.Tensor]: + """Return an object without np.ndarray. Same as + ``to_torch(x, dtype=y.dtype, device=y.device)``. 
+ """ + assert isinstance(y, torch.Tensor) + return to_torch(x, dtype=y.dtype, device=y.device) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 387cfed..dba4357 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -4,7 +4,7 @@ from torch import nn from abc import ABC, abstractmethod from typing import Dict, List, Union, Optional, Callable -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, to_torch_as class BasePolicy(ABC, nn.Module): @@ -138,9 +138,10 @@ class BasePolicy(ABC, nn.Module): batch: Batch, buffer: ReplayBuffer, indice: np.ndarray, - target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray], + target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], gamma: float = 0.99, - n_step: int = 1 + n_step: int = 1, + rew_norm: bool = False ) -> np.ndarray: r"""Compute n-step return for Q-learning targets: @@ -159,13 +160,25 @@ class BasePolicy(ABC, nn.Module): :type buffer: :class:`~tianshou.data.ReplayBuffer` :param indice: sampled timestep. :type indice: numpy.ndarray + :param function target_q_fn: a function receives :math:`t+n-1` step's + data and compute target Q value. :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99. :param int n_step: the number of estimation step, should be an int greater than 0, defaults to 1. + :param bool rew_norm: normalize the reward to Normal(0, 1), defaults + to ``False``. - :return: a Batch. The result will be stored in batch.returns. + :return: a Batch. The result will be stored in batch.returns as a + torch.Tensor with shape (bsz, ). """ + if rew_norm: + bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer + mean, std = bfr.mean(), bfr.std() + if np.isclose(std, 0): + mean, std = 0, 1 + else: + mean, std = 0, 1 returns = np.zeros_like(indice) gammas = np.zeros_like(indice) + n_step done, rew, buf_len = buffer.done, buffer.rew, len(buffer) @@ -173,10 +186,11 @@ class BasePolicy(ABC, nn.Module): now = (indice + n) % buf_len gammas[done[now] > 0] = n returns[done[now] > 0] = 0 - returns = rew[now] + gamma * returns + returns = (rew[now] - mean) / std + gamma * returns terminal = (indice + n_step - 1) % buf_len - target_q = target_q_fn(buffer, terminal) + target_q = target_q_fn(buffer, terminal).squeeze() target_q[gammas != n_step] = 0 - returns += (gamma ** gammas) * target_q - batch.returns = returns + returns = to_torch_as(returns, target_q) + gammas = to_torch_as(gamma ** gammas, target_q) + batch.returns = target_q * gammas + returns return batch diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index bb40a25..f4de285 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from typing import Dict, List, Union, Optional from tianshou.policy import PGPolicy -from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy +from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy class A2CPolicy(PGPolicy): @@ -106,14 +106,14 @@ class A2CPolicy(PGPolicy): self.optim.zero_grad() dist = self(b).dist v = self.critic(b.obs) - a = to_torch(b.act, device=v.device) - r = to_torch(b.returns, device=v.device) + a = to_torch_as(b.act, v) + r = to_torch_as(b.returns, v) a_loss = -(dist.log_prob(a) * (r - v).detach()).mean() vf_loss = F.mse_loss(r[:, None], v) ent_loss = dist.entropy().mean() loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss loss.backward() - if self._grad_norm: + if self._grad_norm is 
not None: nn.utils.clip_grad_norm_( list(self.actor.parameters()) + list(self.critic.parameters()), diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index b211b44..5d3dde3 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional from tianshou.policy import BasePolicy # from tianshou.exploration import OUNoise -from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data import Batch, ReplayBuffer, to_torch_as class DDPGPolicy(BasePolicy): @@ -29,6 +29,8 @@ class DDPGPolicy(BasePolicy): defaults to ``False``. :param bool ignore_done: ignore the done flag while training the policy, defaults to ``False``. + :param int estimation_step: greater than 1, the number of steps to look + ahead. .. seealso:: @@ -47,6 +49,7 @@ class DDPGPolicy(BasePolicy): action_range: Optional[Tuple[float, float]] = None, reward_normalization: bool = False, ignore_done: bool = False, + estimation_step: int = 1, **kwargs) -> None: super().__init__(**kwargs) if actor is not None: @@ -71,6 +74,8 @@ class DDPGPolicy(BasePolicy): # self.noise = OUNoise() self._rm_done = ignore_done self._rew_norm = reward_normalization + assert estimation_step > 0, 'estimation_step should greater than 0' + self._n_step = estimation_step def set_eps(self, eps: float) -> None: """Set the eps for exploration.""" @@ -96,15 +101,21 @@ class DDPGPolicy(BasePolicy): self.critic_old.parameters(), self.critic.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + def _target_q(self, buffer: ReplayBuffer, + indice: np.ndarray) -> torch.Tensor: + batch = buffer[indice] # batch.obs_next: s_{t+n} + with torch.no_grad(): + target_q = self.critic_old(batch.obs_next, self( + batch, model='actor_old', input='obs_next', eps=0).act) + return target_q + def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: - if self._rew_norm: - bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer - mean, std = bfr.mean(), bfr.std() - if not np.isclose(std, 0): - batch.rew = (batch.rew - mean) / std if self._rm_done: batch.done = batch.done * 0. + batch = self.compute_nstep_return( + batch, buffer, indice, self._target_q, + self._gamma, self._n_step, self._rew_norm) return batch def forward(self, batch: Batch, @@ -143,16 +154,9 @@ class DDPGPolicy(BasePolicy): return Batch(act=logits, state=h) def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: - with torch.no_grad(): - target_q = self.critic_old(batch.obs_next, self( - batch, model='actor_old', input='obs_next', eps=0).act) - dev = target_q.device - rew = to_torch(batch.rew, - dtype=torch.float, device=dev)[:, None] - done = to_torch(batch.done, - dtype=torch.float, device=dev)[:, None] - target_q = (rew + (1. 
- done) * self._gamma * target_q) current_q = self.critic(batch.obs, batch.act) + target_q = to_torch_as(batch.returns, current_q) + target_q = target_q[:, None] critic_loss = F.mse_loss(current_q, target_q) self.critic_optim.zero_grad() critic_loss.backward() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index f535079..0df4e0d 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -6,7 +6,7 @@ from typing import Dict, Union, Optional from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ - to_torch, to_numpy + to_torch_as, to_numpy class DQNPolicy(BasePolicy): @@ -69,18 +69,18 @@ class DQNPolicy(BasePolicy): self.model_old.load_state_dict(self.model.state_dict()) def _target_q(self, buffer: ReplayBuffer, - indice: np.ndarray) -> np.ndarray: - data = buffer[indice] + indice: np.ndarray) -> torch.Tensor: + batch = buffer[indice] # batch.obs_next: s_{t+n} if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - a = self(data, input='obs_next', eps=0).act - target_q = self( - data, model='model_old', input='obs_next').logits - target_q = to_numpy(target_q) + a = self(batch, input='obs_next', eps=0).act + with torch.no_grad(): + target_q = self( + batch, model='model_old', input='obs_next').logits target_q = target_q[np.arange(len(a)), a] else: - target_q = self(data, input='obs_next').logits - target_q = to_numpy(target_q).max(axis=1) + with torch.no_grad(): + target_q = self(batch, input='obs_next').logits.max(dim=1)[0] return target_q def process_fn(self, batch: Batch, buffer: ReplayBuffer, @@ -144,12 +144,11 @@ class DQNPolicy(BasePolicy): self.optim.zero_grad() q = self(batch).logits q = q[np.arange(len(q)), batch.act] - r = to_torch(batch.returns, device=q.device, dtype=q.dtype) + r = to_torch_as(batch.returns, q) if hasattr(batch, 'update_weight'): td = r - q batch.update_weight(batch.indice, to_numpy(td)) - impt_weight = to_torch(batch.impt_weight, - device=q.device, dtype=torch.float) + impt_weight = to_torch_as(batch.impt_weight, q) loss = (td.pow(2) * impt_weight).mean() else: loss = F.mse_loss(q, r) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 7c1d538..a7abec5 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -4,7 +4,7 @@ from torch import nn from typing import Dict, List, Tuple, Union, Optional from tianshou.policy import PGPolicy -from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as class PPOPolicy(PGPolicy): @@ -129,14 +129,12 @@ class PPOPolicy(PGPolicy): for b in batch.split(batch_size, shuffle=False): v.append(self.critic(b.obs)) old_log_prob.append(self(b).dist.log_prob( - to_torch(b.act, device=v[0].device))) + to_torch_as(b.act, v[0]))) batch.v = torch.cat(v, dim=0) # old value - dev = batch.v.device - batch.act = to_torch(batch.act, dtype=torch.float, device=dev) + batch.act = to_torch_as(batch.act, v[0]) batch.logp_old = torch.cat(old_log_prob, dim=0) - batch.returns = to_torch( - batch.returns, dtype=torch.float, device=dev - ).reshape(batch.v.shape) + batch.returns = to_torch_as( + batch.returns, v[0]).reshape(batch.v.shape) if self._rew_norm: mean, std = batch.returns.mean(), batch.returns.std() if not np.isclose(std.item(), 0): diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index bbf258c..9dd4b8a 100644 --- a/tianshou/policy/modelfree/sac.py 
+++ b/tianshou/policy/modelfree/sac.py @@ -4,9 +4,9 @@ from copy import deepcopy import torch.nn.functional as F from typing import Dict, Tuple, Union, Optional -from tianshou.data import Batch, to_torch from tianshou.policy import DDPGPolicy from tianshou.policy.dist import DiagGaussian +from tianshou.data import Batch, to_torch_as, ReplayBuffer class SACPolicy(DDPGPolicy): @@ -55,10 +55,11 @@ class SACPolicy(DDPGPolicy): action_range: Optional[Tuple[float, float]] = None, reward_normalization: bool = False, ignore_done: bool = False, + estimation_step: int = 1, **kwargs) -> None: super().__init__(None, None, None, None, tau, gamma, 0, action_range, reward_normalization, ignore_done, - **kwargs) + estimation_step, **kwargs) self.actor, self.actor_optim = actor, actor_optim self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() @@ -105,23 +106,23 @@ class SACPolicy(DDPGPolicy): return Batch( logits=logits, act=act, state=h, dist=dist, log_prob=log_prob) - def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + def _target_q(self, buffer: ReplayBuffer, + indice: np.ndarray) -> torch.Tensor: + batch = buffer[indice] # batch.obs: s_{t+n} with torch.no_grad(): obs_next_result = self(batch, input='obs_next') a_ = obs_next_result.act - dev = a_.device - batch.act = to_torch(batch.act, dtype=torch.float, device=dev) + batch.act = to_torch_as(batch.act, a_) target_q = torch.min( self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_), ) - self._alpha * obs_next_result.log_prob - rew = to_torch(batch.rew, - dtype=torch.float, device=dev)[:, None] - done = to_torch(batch.done, - dtype=torch.float, device=dev)[:, None] - target_q = (rew + (1. - done) * self._gamma * target_q) + return target_q + + def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 1 current_q1 = self.critic1(batch.obs, batch.act) + target_q = to_torch_as(batch.returns, current_q1)[:, None] critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 9f0b3cd..dd3004c 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,10 +1,11 @@ import torch +import numpy as np from copy import deepcopy import torch.nn.functional as F from typing import Dict, Tuple, Optional -from tianshou.data import Batch, to_torch from tianshou.policy import DDPGPolicy +from tianshou.data import Batch, ReplayBuffer class TD3Policy(DDPGPolicy): @@ -62,10 +63,11 @@ class TD3Policy(DDPGPolicy): action_range: Optional[Tuple[float, float]] = None, reward_normalization: bool = False, ignore_done: bool = False, + estimation_step: int = 1, **kwargs) -> None: super().__init__(actor, actor_optim, None, None, tau, gamma, exploration_noise, action_range, reward_normalization, - ignore_done, **kwargs) + ignore_done, estimation_step, **kwargs) self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() self.critic1_optim = critic1_optim @@ -100,25 +102,26 @@ class TD3Policy(DDPGPolicy): self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) - def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + def _target_q(self, buffer: ReplayBuffer, + indice: np.ndarray) -> torch.Tensor: + batch = buffer[indice] # batch.obs: s_{t+n} with torch.no_grad(): a_ = self(batch, model='actor_old', input='obs_next').act dev = a_.device noise = 
torch.randn(size=a_.shape, device=dev) * self._policy_noise - if self._noise_clip >= 0: + if self._noise_clip > 0: noise = noise.clamp(-self._noise_clip, self._noise_clip) a_ += noise a_ = a_.clamp(self._range[0], self._range[1]) target_q = torch.min( self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_)) - rew = to_torch(batch.rew, - dtype=torch.float, device=dev)[:, None] - done = to_torch(batch.done, - dtype=torch.float, device=dev)[:, None] - target_q = (rew + (1. - done) * self._gamma * target_q) + return target_q + + def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 1 current_q1 = self.critic1(batch.obs, batch.act) + target_q = batch.returns[:, None] critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward()
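The heart of this patch is `BasePolicy.compute_nstep_return`: each Q-learning policy's `process_fn` now delegates return estimation to it, passing an algorithm-specific `_target_q` hook, and `learn` then reads the precomputed `batch.returns`. Below is a minimal, standalone NumPy sketch of that n-step target logic, written against plain arrays rather than Tianshou's `ReplayBuffer`/`Batch` types; the function name and arguments are illustrative only, and the optional reward normalization from the patch is omitted.

```python
import numpy as np

def nstep_target(rew, done, bootstrap_q, indice, gamma=0.99, n_step=1):
    """Illustrative n-step return, truncated at episode boundaries:

        G_t = r_t + gamma * r_{t+1} + ... + gamma^{n-1} * r_{t+n-1}
              + gamma^n * Q_target(s_{t+n})

    where rewards after a terminal step are dropped and the bootstrap term
    is zeroed whenever the episode ends inside the window.

    rew, done    -- full circular-buffer arrays, shape (buf_len,)
    bootstrap_q  -- target-network value n steps ahead of each sampled index,
                    i.e. evaluated on the obs_next of step t+n-1, shape (bsz,)
    indice       -- sampled buffer indices t, shape (bsz,)
    """
    buf_len = len(rew)
    returns = np.zeros(len(indice))
    # offset of the first terminal inside the window; stays n_step if none
    steps = np.full(len(indice), n_step)
    for n in range(n_step - 1, -1, -1):        # walk backwards from t+n-1 to t
        now = (indice + n) % buf_len           # circular-buffer indexing
        steps[done[now] > 0] = n               # episode ends at offset n
        returns[done[now] > 0] = 0.            # drop rewards beyond the terminal
        returns = rew[now] + gamma * returns   # accumulate discounted rewards
    # bootstrap only where no terminal was reached within the n steps
    bootstrap_q = np.where(steps == n_step, bootstrap_q, 0.)
    return bootstrap_q * gamma ** steps + returns
```

With `n_step=1` this collapses to the familiar one-step target `r_t + gamma * (1 - d_t) * Q_target(s_{t+1})`, which is why DDPG, TD3, SAC and DQN can all be routed through the same helper and only need to supply their own `_target_q` (e.g. `critic_old(obs_next, actor_old(obs_next))` for DDPG, the minimum of two target critics minus the entropy term for SAC, or the double-DQN bootstrap for DQN).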
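A concrete example of the algorithm-specific `target_q_fn` hook is the double-DQN bootstrap from `DQNPolicy._target_q`; here it is as a standalone PyTorch rendition, where `model` and `model_old` are placeholder callables mapping a batch of next observations to Q values, not Tianshou objects.

```python
import torch

def double_dqn_bootstrap(model, model_old, obs_next):
    """Q_old(s', argmax_a Q_new(s', a)): the online network picks the greedy
    action, the frozen target network evaluates it (illustrative sketch)."""
    with torch.no_grad():
        act = model(obs_next).argmax(dim=1)            # greedy action from Q_new
        q_old = model_old(obs_next)                    # Q_old(s', .)
        return q_old[torch.arange(len(act)), act]      # gather Q_old at those actions
```

In the patch this value is returned as a torch.Tensor and handed to `compute_nstep_return`, so the terminal masking and discounting live in one place instead of being repeated in every `learn` method.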