nstep all (fix #51)

Trinkle23897 2020-06-03 13:59:47 +08:00
parent ff81a18f42
commit dc451dfe88
17 changed files with 127 additions and 84 deletions

View File

@ -20,7 +20,7 @@
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@ -30,7 +30,7 @@
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
Tianshou supports parallel workers for all algorithms as well since all of them are reformatted as replay-buffer based algorithms. All of the algorithms support recurrent state representation in actor network (RNN-style training in POMDP). The environment state can be any type (Dict, self-defined class, ...).
**Tianshou also supports parallel workers for all algorithms, since all of them are reformulated as replay-buffer based algorithms. All of the algorithms support a recurrent state representation in the actor network (RNN-style training for POMDPs). The environment state can be any type (dict, self-defined class, ...). All Q-learning algorithms support n-step return estimation.**
In Chinese, Tianshou means divinely ordained, referring to a gift one is born with. Tianshou is a reinforcement learning platform, and its RL algorithms do not learn from humans. Naming it "Tianshou" means that there is no teacher to study with; instead, the agents learn by themselves through constant interaction with the environment.
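
The feature list now advertises n-step return estimation for the Q-learning family. A minimal sketch of how this is exposed (mirroring the DDPG test script updated below in this commit; `actor`, `actor_optim`, `critic`, `critic_optim`, and `env` are assumed to be built as in that script, and the positional arguments stand for tau, gamma, exploration noise, and action range):

```python
# Sketch only: the networks and `env` are assumed to exist already.
from tianshou.policy import DDPGPolicy

policy = DDPGPolicy(
    actor, actor_optim, critic, critic_optim,
    0.005, 0.99, 0.1,
    [env.action_space.low[0], env.action_space.high[0]],
    reward_normalization=True,
    ignore_done=True,
    estimation_step=4,  # bootstrap the target from Q(s_{t+4}) instead of Q(s_{t+1})
)
```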
@ -102,7 +102,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
| Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? |
| DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
| A2C - CartPole | 6.36±1.63s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
| PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
| PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? |
| DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s |

View File

@ -36,7 +36,9 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--ignore-done', type=int, default=1)
parser.add_argument('--n-step', type=int, default=1)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@ -78,7 +80,9 @@ def test_ddpg(args=get_args()):
actor, actor_optim, critic, critic_optim,
args.tau, args.gamma, args.exploration_noise,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=args.rew_norm, ignore_done=True)
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@ -44,9 +44,9 @@ def get_args():
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=bool, default=True)
# parser.add_argument('--dual-clip', type=float, default=5.)
parser.add_argument('--value-clip', type=bool, default=True)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=1)
args = parser.parse_known_args()[0]
return args

View File

@ -37,7 +37,9 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--ignore-done', type=int, default=1)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@ -83,7 +85,9 @@ def test_sac_with_il(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=args.rew_norm, ignore_done=True)
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
@ -126,7 +130,7 @@ def test_sac_with_il(args=get_args()):
train_collector.reset()
result = offpolicy_trainer(
il_policy, train_collector, il_test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.step_per_epoch // 5, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()

View File

@ -39,7 +39,9 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--ignore-done', type=int, default=1)
parser.add_argument('--n-step', type=int, default=1)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@ -86,7 +88,9 @@ def test_td3(args=get_args()):
args.tau, args.gamma, args.exploration_noise, args.policy_noise,
args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=args.rew_norm, ignore_done=True)
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@ -31,7 +31,7 @@ def get_args():
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=2)
parser.add_argument('--training-num', type=int, default=32)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
@ -40,7 +40,7 @@ def get_args():
default='cuda' if torch.cuda.is_available() else 'cpu')
# a2c special
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.001)
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--max-grad-norm', type=float, default=None)
parser.add_argument('--gae-lambda', type=float, default=1.)
parser.add_argument('--rew-norm', type=bool, default=False)

View File

@ -102,7 +102,7 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')

View File

@ -43,9 +43,9 @@ def get_args():
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.8)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=bool, default=True)
parser.add_argument('--value-clip', type=int, default=1)
args = parser.parse_known_args()[0]
return args

View File

@ -1,5 +1,6 @@
from tianshou.data.batch import Batch
from tianshou.data.utils import to_numpy, to_torch
from tianshou.data.utils import to_numpy, to_torch, \
to_torch_as
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector
@ -8,6 +9,7 @@ __all__ = [
'Batch',
'to_numpy',
'to_torch',
'to_torch_as',
'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer',

View File

@ -38,3 +38,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
elif isinstance(x, Batch):
x.to_torch(dtype, device)
return x
def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray],
y: torch.Tensor
) -> Union[dict, Batch, torch.Tensor]:
"""Return an object without np.ndarray. Same as
``to_torch(x, dtype=y.dtype, device=y.device)``.
"""
assert isinstance(y, torch.Tensor)
return to_torch(x, dtype=y.dtype, device=y.device)
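
A quick illustration of the new helper, as a sketch with made-up tensors (it assumes, per the docstring above, that ``to_torch`` casts numpy input to the given dtype and device):

```python
import numpy as np
import torch

from tianshou.data import to_torch_as

q = torch.zeros(4)                       # reference tensor (float32, cpu)
returns = np.ones(4, dtype=np.float64)   # numpy data to convert
r = to_torch_as(returns, q)              # takes q's dtype and device
assert r.dtype == q.dtype and r.device == q.device
```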

View File

@ -4,7 +4,7 @@ from torch import nn
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional, Callable
from tianshou.data import Batch, ReplayBuffer
from tianshou.data import Batch, ReplayBuffer, to_torch_as
class BasePolicy(ABC, nn.Module):
@ -138,9 +138,10 @@ class BasePolicy(ABC, nn.Module):
batch: Batch,
buffer: ReplayBuffer,
indice: np.ndarray,
target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray],
target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
gamma: float = 0.99,
n_step: int = 1
n_step: int = 1,
rew_norm: bool = False
) -> np.ndarray:
r"""Compute n-step return for Q-learning targets:
@ -159,13 +160,25 @@ class BasePolicy(ABC, nn.Module):
:type buffer: :class:`~tianshou.data.ReplayBuffer`
:param indice: sampled timestep.
:type indice: numpy.ndarray
:param function target_q_fn: a function which receives the :math:`t+n-1`
step's data and computes the target Q value.
:param float gamma: the discount factor, should be in [0, 1], defaults
to 0.99.
:param int n_step: the number of estimation steps, should be an int
greater than 0, defaults to 1.
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
to ``False``.
:return: a Batch. The result will be stored in batch.returns.
:return: a Batch. The result will be stored in batch.returns as a
torch.Tensor with shape (bsz, ).
"""
if rew_norm:
bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
if np.isclose(std, 0):
mean, std = 0, 1
else:
mean, std = 0, 1
returns = np.zeros_like(indice)
gammas = np.zeros_like(indice) + n_step
done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
@ -173,10 +186,11 @@ class BasePolicy(ABC, nn.Module):
now = (indice + n) % buf_len
gammas[done[now] > 0] = n
returns[done[now] > 0] = 0
returns = rew[now] + gamma * returns
returns = (rew[now] - mean) / std + gamma * returns
terminal = (indice + n_step - 1) % buf_len
target_q = target_q_fn(buffer, terminal)
target_q = target_q_fn(buffer, terminal).squeeze()
target_q[gammas != n_step] = 0
returns += (gamma ** gammas) * target_q
batch.returns = returns
returns = to_torch_as(returns, target_q)
gammas = to_torch_as(gamma ** gammas, target_q)
batch.returns = target_q * gammas + returns
return batch
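
For reference, this helper assembles the usual n-step target r_t + γ·r_{t+1} + … + γ^{n-1}·r_{t+n-1} + γ^n·Q_target(s_{t+n}), truncating at episode ends and optionally normalizing rewards. A minimal sketch of how a policy is expected to wire it into ``process_fn`` (illustrative only; the DDPG changes later in this commit do exactly this):

```python
import numpy as np

from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import DDPGPolicy


class NStepDDPG(DDPGPolicy):
    """Illustrative subclass; DDPGPolicy in this commit already does this."""

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        # delegate target computation to the shared n-step helper;
        # self._target_q returns Q_target(s_{t+n}, a) as a torch.Tensor
        return self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            gamma=self._gamma, n_step=self._n_step, rew_norm=self._rew_norm)
```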

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from typing import Dict, List, Union, Optional
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class A2CPolicy(PGPolicy):
@ -106,14 +106,14 @@ class A2CPolicy(PGPolicy):
self.optim.zero_grad()
dist = self(b).dist
v = self.critic(b.obs)
a = to_torch(b.act, device=v.device)
r = to_torch(b.returns, device=v.device)
a = to_torch_as(b.act, v)
r = to_torch_as(b.returns, v)
a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
vf_loss = F.mse_loss(r[:, None], v)
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()
if self._grad_norm:
if self._grad_norm is not None:
nn.utils.clip_grad_norm_(
list(self.actor.parameters()) +
list(self.critic.parameters()),

View File

@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as
class DDPGPolicy(BasePolicy):
@ -29,6 +29,8 @@ class DDPGPolicy(BasePolicy):
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
:param int estimation_step: the number of steps to look ahead, should be
greater than 0, defaults to 1.
.. seealso::
@ -47,6 +49,7 @@ class DDPGPolicy(BasePolicy):
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs) -> None:
super().__init__(**kwargs)
if actor is not None:
@ -71,6 +74,8 @@ class DDPGPolicy(BasePolicy):
# self.noise = OUNoise()
self._rm_done = ignore_done
self._rew_norm = reward_normalization
assert estimation_step > 0, 'estimation_step should be greater than 0'
self._n_step = estimation_step
def set_eps(self, eps: float) -> None:
"""Set the eps for exploration."""
@ -96,15 +101,21 @@ class DDPGPolicy(BasePolicy):
self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)
return target_q
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._rew_norm:
bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
if not np.isclose(std, 0):
batch.rew = (batch.rew - mean) / std
if self._rm_done:
batch.done = batch.done * 0.
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm)
return batch
def forward(self, batch: Batch,
@ -143,16 +154,9 @@ class DDPGPolicy(BasePolicy):
return Batch(act=logits, state=h)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)
dev = target_q.device
rew = to_torch(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = to_torch(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
current_q = self.critic(batch.obs, batch.act)
target_q = to_torch_as(batch.returns, current_q)
target_q = target_q[:, None]
critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
critic_loss.backward()

View File

@ -6,7 +6,7 @@ from typing import Dict, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
to_torch, to_numpy
to_torch_as, to_numpy
class DQNPolicy(BasePolicy):
@ -69,18 +69,18 @@ class DQNPolicy(BasePolicy):
self.model_old.load_state_dict(self.model.state_dict())
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> np.ndarray:
data = buffer[indice]
indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(data, input='obs_next', eps=0).act
target_q = self(
data, model='model_old', input='obs_next').logits
target_q = to_numpy(target_q)
a = self(batch, input='obs_next', eps=0).act
with torch.no_grad():
target_q = self(
batch, model='model_old', input='obs_next').logits
target_q = target_q[np.arange(len(a)), a]
else:
target_q = self(data, input='obs_next').logits
target_q = to_numpy(target_q).max(axis=1)
with torch.no_grad():
target_q = self(batch, input='obs_next').logits.max(dim=1)[0]
return target_q
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
@ -144,12 +144,11 @@ class DQNPolicy(BasePolicy):
self.optim.zero_grad()
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
r = to_torch_as(batch.returns, q)
if hasattr(batch, 'update_weight'):
td = r - q
batch.update_weight(batch.indice, to_numpy(td))
impt_weight = to_torch(batch.impt_weight,
device=q.device, dtype=torch.float)
impt_weight = to_torch_as(batch.impt_weight, q)
loss = (td.pow(2) * impt_weight).mean()
else:
loss = F.mse_loss(q, r)
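
The ``# target_Q = Q_old(s_, argmax(Q_new(s_, *)))`` comment above is the Double DQN target. A tiny standalone illustration of that indexing pattern with made-up numbers (independent of Tianshou):

```python
import torch

# Toy batch of two next-states with two actions each.
q_new = torch.tensor([[1.0, 2.0], [3.0, 0.5]])   # online network: Q(s', .)
q_old = torch.tensor([[0.9, 1.8], [2.5, 0.7]])   # target network: Q_old(s', .)

a = q_new.argmax(dim=1)                          # greedy action from the online net
target_q = q_old[torch.arange(len(a)), a]        # value taken from the target net
print(target_q)                                  # tensor([1.8000, 2.5000])
```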

View File

@ -4,7 +4,7 @@ from torch import nn
from typing import Dict, List, Tuple, Union, Optional
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
class PPOPolicy(PGPolicy):
@ -129,14 +129,12 @@ class PPOPolicy(PGPolicy):
for b in batch.split(batch_size, shuffle=False):
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch(b.act, device=v[0].device)))
to_torch_as(b.act, v[0])))
batch.v = torch.cat(v, dim=0) # old value
dev = batch.v.device
batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch(
batch.returns, dtype=torch.float, device=dev
).reshape(batch.v.shape)
batch.returns = to_torch_as(
batch.returns, v[0]).reshape(batch.v.shape)
if self._rew_norm:
mean, std = batch.returns.mean(), batch.returns.std()
if not np.isclose(std.item(), 0):

View File

@ -4,9 +4,9 @@ from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.data import Batch, to_torch
from tianshou.policy import DDPGPolicy
from tianshou.policy.dist import DiagGaussian
from tianshou.data import Batch, to_torch_as, ReplayBuffer
class SACPolicy(DDPGPolicy):
@ -55,10 +55,11 @@ class SACPolicy(DDPGPolicy):
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs) -> None:
super().__init__(None, None, None, None, tau, gamma, 0,
action_range, reward_normalization, ignore_done,
**kwargs)
estimation_step, **kwargs)
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
@ -105,23 +106,23 @@ class SACPolicy(DDPGPolicy):
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
dev = a_.device
batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
batch.act = to_torch_as(batch.act, a_)
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_),
) - self._alpha * obs_next_result.log_prob
rew = to_torch(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = to_torch(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
return target_q
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)
target_q = to_torch_as(batch.returns, current_q1)[:, None]
critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()

View File

@ -1,10 +1,11 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from tianshou.data import Batch, to_torch
from tianshou.policy import DDPGPolicy
from tianshou.data import Batch, ReplayBuffer
class TD3Policy(DDPGPolicy):
@ -62,10 +63,11 @@ class TD3Policy(DDPGPolicy):
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs) -> None:
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization,
ignore_done, **kwargs)
ignore_done, estimation_step, **kwargs)
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
self.critic1_optim = critic1_optim
@ -100,25 +102,26 @@ class TD3Policy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
a_ = self(batch, model='actor_old', input='obs_next').act
dev = a_.device
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
if self._noise_clip >= 0:
if self._noise_clip > 0:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise
a_ = a_.clamp(self._range[0], self._range[1])
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_))
rew = to_torch(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = to_torch(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
return target_q
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)
target_q = batch.returns[:, None]
critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()