Trinkle23897 2020-04-14 21:11:06 +08:00
parent 7b65d43394
commit 680fc0ffbe
13 changed files with 129 additions and 51 deletions

View File

@@ -26,6 +26,7 @@
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - Vanilla Imitation Learning
+- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.

View File

@@ -16,7 +16,8 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
-* :class:`~tianshou.policy.ImitationPolicy`
+* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
+* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.

View File

@@ -12,8 +12,6 @@ else:  # pytest
 class MyPolicy(BasePolicy):
-    """docstring for MyPolicy"""
     def __init__(self):
         super().__init__()

View File

@@ -6,8 +6,8 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import PPOPolicy
 from tianshou.env import VectorEnv
+from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
@@ -22,15 +22,15 @@ def get_args():
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
-    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=10)
+    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=16)
+    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
     args = parser.parse_known_args()[0]
     return args
@@ -84,12 +85,12 @@ def _test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=[env.action_space.low[0], env.action_space.high[0]])
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
-    train_collector.collect(n_step=args.step_per_epoch)
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)

View File

@@ -24,7 +24,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
@@ -41,6 +41,7 @@ def get_args():
     parser.add_argument('--vf-coef', type=float, default=0.5)
     parser.add_argument('--ent-coef', type=float, default=0.001)
     parser.add_argument('--max-grad-norm', type=float, default=None)
+    parser.add_argument('--gae-lambda', type=float, default=1.)
     args = parser.parse_known_args()[0]
     return args
@@ -70,8 +71,9 @@ def test_a2c(args=get_args()):
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
-        actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
-        ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
+        actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
+        vf_coef=args.vf_coef, ent_coef=args.ent_coef,
+        max_grad_norm=args.max_grad_norm)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@@ -28,7 +28,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)

View File

@@ -29,7 +29,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)

View File

@@ -80,7 +80,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)

View File

@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
     args = parser.parse_known_args()[0]
     return args
@@ -76,7 +77,8 @@ def test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=None)
+        action_range=None,
+        gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@@ -1,3 +1,4 @@
+import numpy as np
 from torch import nn
 from abc import ABC, abstractmethod
@@ -74,3 +75,34 @@ class BasePolicy(ABC, nn.Module):
         :return: A dict which includes loss and its corresponding label.
         """
         pass
+
+    def compute_episodic_return(self, batch, v_s_=None,
+                                gamma=0.99, gae_lambda=0.95):
+        """Compute returns over given full-length episodes, including the
+        implementation of Generalized Advantage Estimation (arXiv:1506.02438).
+
+        :param batch: a data batch which contains several full-episode data
+            chronologically.
+        :type batch: :class:`~tianshou.data.Batch`
+        :param v_s_: the value function of all next states :math:`V(s')`.
+        :type v_s_: numpy.ndarray
+        :param float gamma: the discount factor, should be in [0, 1], defaults
+            to 0.99.
+        :param float gae_lambda: the parameter for Generalized Advantage
+            Estimation, should be in [0, 1], defaults to 0.95.
+        """
+        if v_s_ is None:
+            v_s_ = np.zeros_like(batch.rew)
+        if not isinstance(v_s_, np.ndarray):
+            v_s_ = np.array(v_s_, np.float)
+        else:
+            v_s_ = v_s_.flatten()
+        batch.returns = np.roll(v_s_, 1)
+        m = (1. - batch.done) * gamma
+        delta = batch.rew + v_s_ * m - batch.returns
+        m *= gae_lambda
+        gae = 0.
+        for i in range(len(batch.rew) - 1, -1, -1):
+            gae = delta[i] + m[i] * gae
+            batch.returns[i] += gae
+        return batch
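
For reference, the loop above follows the usual GAE recursion: delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, accumulated backwards; batch.returns ends up holding V(s_t) + A_t (the lambda-return target), with V(s_t) recovered by rolling v_s_ one step. A minimal standalone NumPy sketch of the same recursion, with an illustrative helper name (gae_advantage) and toy inputs that are not part of the commit:

import numpy as np

def gae_advantage(rew, done, v_s, v_s_, gamma=0.99, gae_lambda=0.95):
    # delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
    # A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, computed backwards
    adv = np.zeros_like(rew)
    gae = 0.
    for i in range(len(rew) - 1, -1, -1):
        delta = rew[i] + gamma * (1. - done[i]) * v_s_[i] - v_s[i]
        gae = delta + gamma * gae_lambda * (1. - done[i]) * gae
        adv[i] = gae
    return adv

# toy batch: two episodes of constant reward, zero value baseline
rew = np.array([1., 1., 1., 1., 1.])
done = np.array([0., 0., 1., 0., 1.])
print(gae_advantage(rew, done, np.zeros(5), np.zeros(5)))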

View File

@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 import torch.nn.functional as F
@@ -21,6 +22,8 @@ class A2CPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation,
         defaults to ``None``.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage
+        Estimation, defaults to 0.95.

     .. seealso::
@@ -31,13 +34,28 @@ class A2CPolicy(PGPolicy):
     def __init__(self, actor, critic, optim,
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
-                 max_grad_norm=None, **kwargs):
+                 max_grad_norm=None, gae_lambda=0.95, **kwargs):
         super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
         self.critic = critic
+        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
+        self._lambda = gae_lambda
         self._w_vf = vf_coef
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm
+        self._batch = 64
+
+    def process_fn(self, batch, buffer, indice):
+        if self._lambda in [0, 1]:
+            return self.compute_episodic_return(
+                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
+        v_ = []
+        with torch.no_grad():
+            for b in batch.split(self._batch * 4, permute=False):
+                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+        v_ = np.concatenate(v_, axis=0)
+        return self.compute_episodic_return(
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)

     def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
@@ -63,6 +81,7 @@
         return Batch(logits=logits, act=act, state=h, dist=dist)

     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
+        self._batch = batch_size
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size):
@@ -70,12 +89,11 @@
                 result = self(b)
                 dist = result.dist
                 v = self.critic(b.obs)
-                a = torch.tensor(b.act, device=dist.logits.device)
-                r = torch.tensor(b.returns, device=dist.logits.device)
+                a = torch.tensor(b.act, device=v.device)
+                r = torch.tensor(b.returns, device=v.device)
                 a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r[:, None], v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
                 if self._grad_norm:
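
The new process_fn queries the critic only when 0 < lambda < 1, and does so on obs_next in chunks of self._batch * 4 under torch.no_grad() so that peak memory stays bounded; learn then stores the current batch_size into self._batch so the chunk size tracks the training batch. The same pattern in isolation, as a rough sketch (batched_values and chunk_size are illustrative names; the critic is assumed to accept a slice of observations directly, as the networks in the tests do):

import numpy as np
import torch

def batched_values(critic, obs, chunk_size=256):
    # evaluate a (possibly large) observation array in fixed-size chunks,
    # without building a graph, and stitch the results back together on CPU
    out = []
    with torch.no_grad():
        for start in range(0, len(obs), chunk_size):
            v = critic(obs[start:start + chunk_size])
            out.append(v.cpu().numpy())
    return np.concatenate(out, axis=0)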

View File

@@ -39,9 +39,11 @@ class PGPolicy(BasePolicy):
         , where :math:`T` is the terminal time step, :math:`\gamma` is the
         discount factor, :math:`\gamma \in [0, 1]`.
         """
-        batch.returns = self._vanilla_returns(batch)
+        # batch.returns = self._vanilla_returns(batch)
         # batch.returns = self._vectorized_returns(batch)
-        return batch
+        # return batch
+        return self.compute_episodic_return(
+            batch, gamma=self._gamma, gae_lambda=1.)

     def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
@@ -82,26 +84,26 @@
             losses.append(loss.item())
         return {'loss': losses}

-    def _vanilla_returns(self, batch):
-        returns = batch.rew[:]
-        last = 0
-        for i in range(len(returns) - 1, -1, -1):
-            if not batch.done[i]:
-                returns[i] += self._gamma * last
-            last = returns[i]
-        return returns
+    # def _vanilla_returns(self, batch):
+    #     returns = batch.rew[:]
+    #     last = 0
+    #     for i in range(len(returns) - 1, -1, -1):
+    #         if not batch.done[i]:
+    #             returns[i] += self._gamma * last
+    #         last = returns[i]
+    #     return returns

-    def _vectorized_returns(self, batch):
-        # according to my tests, it is slower than _vanilla_returns
-        # import scipy.signal
-        convolve = np.convolve
-        # convolve = scipy.signal.convolve
-        rew = batch.rew[::-1]
-        batch_size = len(rew)
-        gammas = self._gamma ** np.arange(batch_size)
-        c = convolve(rew, gammas)[:batch_size]
-        T = np.where(batch.done[::-1])[0]
-        d = np.zeros_like(rew)
-        d[T] += c[T] - rew[T]
-        d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
-        return (c - convolve(d, gammas)[:batch_size])[::-1]
+    # def _vectorized_returns(self, batch):
+    #     # according to my tests, it is slower than _vanilla_returns
+    #     # import scipy.signal
+    #     convolve = np.convolve
+    #     # convolve = scipy.signal.convolve
+    #     rew = batch.rew[::-1]
+    #     batch_size = len(rew)
+    #     gammas = self._gamma ** np.arange(batch_size)
+    #     c = convolve(rew, gammas)[:batch_size]
+    #     T = np.where(batch.done[::-1])[0]
+    #     d = np.zeros_like(rew)
+    #     d[T] += c[T] - rew[T]
+    #     d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
+    #     return (c - convolve(d, gammas)[:batch_size])[::-1]
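
The commented-out helpers are superseded by compute_episodic_return with gae_lambda=1.: with no value estimate (v_s_ left as zeros) and lambda = 1, the GAE recursion reduces to the per-episode discounted return that _vanilla_returns produced. A small self-contained check of that equivalence, with made-up helper names:

import numpy as np

def vanilla_returns(rew, done, gamma):
    # backward pass of the old _vanilla_returns helper
    ret = rew.copy()
    last = 0.
    for i in range(len(rew) - 1, -1, -1):
        if not done[i]:
            ret[i] += gamma * last
        last = ret[i]
    return ret

def gae_returns_zero_value(rew, done, gamma, gae_lambda=1.):
    # the compute_episodic_return recursion with v_s_ fixed to zero
    ret = np.zeros_like(rew)
    m = (1. - done) * gamma * gae_lambda
    gae = 0.
    for i in range(len(rew) - 1, -1, -1):
        gae = rew[i] + m[i] * gae
        ret[i] = gae
    return ret

rew = np.array([1., 2., 3., 4.])
done = np.array([0., 1., 0., 1.])
print(vanilla_returns(rew, done, 0.9))         # [2.8 2.  6.6 4. ]
print(gae_returns_zero_value(rew, done, 0.9))  # identical when gae_lambda == 1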

View File

@@ -26,6 +26,8 @@ class PPOPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param action_range: the action range (minimum, maximum).
     :type action_range: [float, float]
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage
+        Estimation, defaults to 0.95.

     .. seealso::
@@ -40,6 +42,7 @@
                 vf_coef=.5,
                 ent_coef=.0,
                 action_range=None,
+                gae_lambda=0.95,
                 **kwargs):
         super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
@@ -52,6 +55,9 @@
         self.critic, self.critic_old = critic, deepcopy(critic)
         self.critic_old.eval()
         self.optim = optim
+        self._batch = 64
+        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
+        self._lambda = gae_lambda

     def train(self):
         """Set the module in training mode, except for the target network."""
@@ -65,6 +71,19 @@
         self.actor.eval()
         self.critic.eval()

+    def process_fn(self, batch, buffer, indice):
+        if self._lambda in [0, 1]:
+            return self.compute_episodic_return(
+                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
+        v_ = []
+        with torch.no_grad():
+            for b in batch.split(self._batch * 4, permute=False):
+                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+        v_ = np.concatenate(v_, axis=0)
+        batch.v_ = v_
+        return self.compute_episodic_return(
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
+
     def forward(self, batch, state=None, model='actor', **kwargs):
         """Compute action over the given batch data.
@@ -97,18 +116,20 @@
         self.critic_old.load_state_dict(self.critic.state_dict())

     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
+        self._batch = batch_size
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         r = batch.returns
         batch.returns = (r - r.mean()) / (r.std() + self._eps)
         batch.act = torch.tensor(batch.act)
         batch.returns = torch.tensor(batch.returns)[:, None]
+        batch.v_ = torch.tensor(batch.v_)
         for _ in range(repeat):
             for b in batch.split(batch_size):
-                vs_old, vs__old = self.critic_old(np.concatenate([
-                    b.obs, b.obs_next])).split(b.obs.shape[0])
+                vs_old = self.critic_old(b.obs)
+                vs__old = b.v_.to(vs_old.device)
                 dist = self(b).dist
                 dist_old = self(b, model='actor_old').dist
-                target_v = b.returns.to(vs__old.device) + self._gamma * vs__old
+                target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
                 adv = (target_v - vs_old).detach()
                 a = b.act.to(adv.device)
                 ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))