Refactor PG algorithm and change behavior of compute_episodic_return (#319)

- simplify code
- apply value normalization (global) and advantage normalization (per-batch) in on-policy algorithms (sketched after the file summary below)
ChenDRAG 2021-03-23 22:05:48 +08:00 committed by GitHub
parent 2c11b6e43b
commit e27b5a26f3
9 changed files with 109 additions and 192 deletions
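To make the commit message concrete: value (return) targets are now normalized against global running statistics, while advantages are standardized within each collected batch. Below is a rough sketch of both schemes; the RunningMeanStd class here is a simplified stand-in assumed to expose the same mean / var / update interface as tianshou.utils.RunningMeanStd used in the diffs, not the library implementation.

import numpy as np

class RunningMeanStd:
    """Simplified stand-in for tianshou.utils.RunningMeanStd (assumption, not library code)."""

    def __init__(self) -> None:
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x: np.ndarray) -> None:
        # parallel mean/variance update over the incoming batch
        batch_mean, batch_var, batch_count = x.mean(), x.var(), len(x)
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = (self.var * self.count + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total

eps = 1e-8
ret_rms = RunningMeanStd()

def normalize_returns(unnormalized_returns: np.ndarray) -> np.ndarray:
    # global scheme: scale value targets by the running return statistics
    normalized = (unnormalized_returns - ret_rms.mean) / np.sqrt(ret_rms.var + eps)
    ret_rms.update(unnormalized_returns)
    return normalized

def normalize_advantages(advantages: np.ndarray) -> np.ndarray:
    # per-batch scheme: zero mean, unit std within the current batch only
    return (advantages - advantages.mean()) / advantages.std()

The A2C, PG, and PPO diffs below wire these two steps into their process_fn methods.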


@@ -30,9 +30,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -41,9 +41,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,9 +52,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -64,12 +64,12 @@ def test_episodic_returns(size=2560):
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
ground_truth = np.array([
454.8344, 376.1143, 291.298, 200.,
464.5610, 383.1085, 295.387, 201.,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
assert np.allclose(returns, ground_truth)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -82,12 +82,12 @@ def test_episodic_returns(size=2560):
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
ground_truth = np.array([
454.0109, 375.2386, 290.3669, 199.01,
462.9138, 381.3571, 293.5248, 199.02,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
assert np.allclose(returns, ground_truth)
if __name__ == '__main__':
buf = ReplayBuffer(size)
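These hunks reflect the new calling convention: compute_episodic_return no longer takes a rew_norm flag and no longer writes batch.returns in place; it returns a (returns, advantage) tuple and leaves storage and normalization to the caller. A minimal usage sketch mirroring the second test case above; the reward vector is reconstructed to be consistent with the expected answer and is an assumption, since it is not visible in the hunk.

import numpy as np
from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy

buf = ReplayBuffer(20)
batch = Batch(
    done=np.array([0, 1, 0, 1, 0, 1, 0.]),
    rew=np.array([7, 6, 1, 2, 3, 4, 5.]),  # assumed rewards, chosen to match the test's answer
)
for b in batch:
    b.obs = b.act = 1
    buf.add(b)

# new signature: a (returns, advantage) tuple comes back instead of a Batch
returns, advantage = BasePolicy.compute_episodic_return(
    batch, buf, buf.sample_index(0), gamma=0.1, gae_lambda=1.0)
batch.returns = returns  # storing (and normalizing) the result is now the caller's job
assert np.allclose(returns, [7.6, 6, 1.2, 2, 3.4, 4, 5])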


@@ -91,7 +91,8 @@ def test_ppo(args=get_args()):
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
actor, critic, optim, dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,


@@ -78,7 +78,8 @@ def test_a2c_with_il(args=get_args()):
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
actor, critic, optim, dist,
discount_factor=args.gamma, gae_lambda=args.gae_lambda,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
action_space=env.action_space)


@@ -17,7 +17,7 @@ from tianshou.data import Collector, VectorReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.95)
@@ -27,7 +27,7 @@ def get_args():
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@@ -65,6 +65,11 @@ def test_pg(args=get_args()):
policy = PGPolicy(net, optim, dist, args.gamma,
reward_normalization=args.rew_norm,
action_space=env.action_space)
for m in net.modules():
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
# collector
train_collector = Collector(
policy, train_envs,


@@ -80,7 +80,8 @@ def test_ppo(args=get_args()):
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
actor, critic, optim, dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,


@@ -4,7 +4,7 @@ import numpy as np
from torch import nn
from numba import njit
from abc import ABC, abstractmethod
from typing import Any, Dict, Union, Optional, Callable
from typing import Any, Dict, Tuple, Union, Optional, Callable
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -254,14 +254,14 @@ class BasePolicy(ABC, nn.Module):
buffer: ReplayBuffer,
indice: np.ndarray,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
v_s: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
rew_norm: bool = False,
) -> Batch:
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute returns over given batch.
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
to calculate q function/reward to go of given batch.
to calculate q/advantage value of given batch.
:param Batch batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch
@@ -273,10 +273,8 @@ class BasePolicy(ABC, nn.Module):
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage Estimation,
should be in [0, 1]. Default to 0.95.
:param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.
:return: a Batch. The result will be stored in batch.returns as a numpy
array with shape (bsz, ).
:return: two numpy arrays (returns, advantage) with each shape (bsz, ).
"""
rew = batch.rew
if v_s_ is None:
@@ -284,14 +282,14 @@ class BasePolicy(ABC, nn.Module):
v_s_ = np.zeros_like(rew)
else:
v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())
end_flag = batch.done.copy()
end_flag[np.isin(indice, buffer.unfinished_index())] = True
returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
batch.returns = returns
return batch
advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
returns = advantage + v_s
# normalization varies from each policy, so we don't do it here
return returns, advantage
@staticmethod
def compute_nstep_return(
@@ -355,8 +353,6 @@ class BasePolicy(ABC, nn.Module):
i64 = np.array([[0, 1]], dtype=np.int64)
_gae_return(f64, f64, f64, b, 0.1, 0.1)
_gae_return(f32, f32, f64, b, 0.1, 0.1)
_episodic_return(f64, f64, b, 0.1, 0.1)
_episodic_return(f32, f64, b, 0.1, 0.1)
_nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1)
@@ -379,19 +375,6 @@ def _gae_return(
return returns
@njit
def _episodic_return(
v_s_: np.ndarray,
rew: np.ndarray,
end_flag: np.ndarray,
gamma: float,
gae_lambda: float,
) -> np.ndarray:
"""Numba speedup: 4.1s -> 0.057s."""
v_s = np.roll(v_s_, 1)
return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s
@njit
def _nstep_return(
rew: np.ndarray,

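In place of the removed _episodic_return helper, the refactored method composes everything from _gae_return: v_s_ is masked to zero at terminal states via value_mask, v_s defaults to np.roll(v_s_, 1), and the value targets are simply advantage + v_s. A plain-NumPy sketch of that behavior (it mirrors the hunks above, not the numba-compiled implementation):

import numpy as np

def gae_advantage(v_s, v_s_, rew, end_flag, gamma, gae_lambda):
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); v_s_ is assumed to be
    # zeroed at terminal states already, end_flag only cuts the recursion
    delta = rew + gamma * v_s_ - v_s
    discount = (1.0 - end_flag) * gamma * gae_lambda
    adv, gae = np.zeros_like(rew), 0.0
    for t in range(len(rew) - 1, -1, -1):
        gae = delta[t] + discount[t] * gae
        adv[t] = gae
    return adv

def episodic_return(rew, end_flag, v_s_, v_s=None, gamma=0.99, gae_lambda=0.95):
    v_s = np.roll(v_s_, 1) if v_s is None else v_s  # the old _episodic_return default
    advantage = gae_advantage(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
    return advantage + v_s, advantage  # (returns, advantage), both unnormalized

With gae_lambda=1.0 the deltas telescope, so advantage + v_s reduces to the discounted reward-to-go; the PG policy below relies on exactly that.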

@@ -2,7 +2,7 @@ import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, List, Type, Union, Optional
from typing import Any, Dict, List, Type, Optional
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -53,17 +53,14 @@ class A2CPolicy(PGPolicy):
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Type[torch.distributions.Distribution],
discount_factor: float = 0.99,
vf_coef: float = 0.5,
ent_coef: float = 0.01,
max_grad_norm: Optional[float] = None,
gae_lambda: float = 0.95,
reward_normalization: bool = False,
max_batchsize: int = 256,
**kwargs: Any
) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
super().__init__(actor, optim, dist_fn, **kwargs)
self.critic = critic
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
self._lambda = gae_lambda
@@ -71,51 +68,27 @@ class A2CPolicy(PGPolicy):
self._weight_ent = ent_coef
self._grad_norm = max_grad_norm
self._batch = max_batchsize
self._rew_norm = reward_normalization
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._lambda in [0.0, 1.0]:
return self.compute_episodic_return(
batch, buffer, indice,
None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
v_s_ = []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
v_.append(to_numpy(self.critic(b.obs_next)))
v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return(
batch, buffer, indice, v_,
gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm)
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
v_s_.append(to_numpy(self.critic(b.obs_next)))
v_s_ = np.concatenate(v_s_, axis=0)
if self._rew_norm: # unnormalize v_s_
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
unnormalized_returns, _ = self.compute_episodic_return(
batch, buffer, indice, v_s_=v_s_,
gamma=self._gamma, gae_lambda=self._lambda)
if self._rew_norm:
batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
np.sqrt(self.ret_rms.var + self._eps)
self.ret_rms.update(unnormalized_returns)
else:
dist = self.dist_fn(logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
batch.returns = unnormalized_returns
return batch
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any

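Read as a whole, the interleaved hunk above turns A2C's process_fn into four steps: evaluate V(s_{t+1}) in minibatches, un-normalize it with the running return statistics when reward_normalization is on (the critic is trained on normalized targets), compute unnormalized GAE returns, then re-normalize the stored targets and update the statistics. A sketch of that flow as a free function over a policy exposing the attributes named in the diff (critic, _batch, _rew_norm, _gamma, _lambda, _eps, ret_rms); it is an illustration, not the library method:

import numpy as np
import torch
from tianshou.data import to_numpy

def a2c_process_fn(policy, batch, buffer, indice):
    v_s_ = []
    with torch.no_grad():
        for b in batch.split(policy._batch, shuffle=False, merge_last=True):
            v_s_.append(to_numpy(policy.critic(b.obs_next)))
    v_s_ = np.concatenate(v_s_, axis=0)
    if policy._rew_norm:
        # the critic predicts normalized values, so undo the scaling before GAE
        v_s_ = v_s_ * np.sqrt(policy.ret_rms.var + policy._eps) + policy.ret_rms.mean
    unnormalized_returns, _ = policy.compute_episodic_return(
        batch, buffer, indice, v_s_=v_s_,
        gamma=policy._gamma, gae_lambda=policy._lambda)
    if policy._rew_norm:
        batch.returns = (unnormalized_returns - policy.ret_rms.mean) / \
            np.sqrt(policy.ret_rms.var + policy._eps)
        policy.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    return batch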

@@ -4,10 +4,11 @@ from typing import Any, Dict, List, Type, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.utils import RunningMeanStd
class PGPolicy(BasePolicy):
"""Implementation of Vanilla Policy Gradient.
"""Implementation of REINFORCE algorithm.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -33,7 +34,7 @@ class PGPolicy(BasePolicy):
def __init__(
self,
model: Optional[torch.nn.Module],
model: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Type[torch.distributions.Distribution],
discount_factor: float = 0.99,
@@ -45,14 +46,15 @@ class PGPolicy(BasePolicy):
) -> None:
super().__init__(action_scaling=action_scaling,
action_bound_method=action_bound_method, **kwargs)
if model is not None:
self.model: torch.nn.Module = model
self.actor = model
self.optim = optim
self.lr_scheduler = lr_scheduler
self.dist_fn = dist_fn
assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
self._gamma = discount_factor
self._rew_norm = reward_normalization
self.ret_rms = RunningMeanStd()
self._eps = 1e-8
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -65,11 +67,16 @@ class PGPolicy(BasePolicy):
where :math:`T` is the terminal time step, :math:`\gamma` is the
discount factor, :math:`\gamma \in [0, 1]`.
"""
# batch.returns = self._vanilla_returns(batch)
# batch.returns = self._vectorized_returns(batch)
return self.compute_episodic_return(
batch, buffer, indice, gamma=self._gamma,
gae_lambda=1.0, rew_norm=self._rew_norm)
v_s_ = np.full(indice.shape, self.ret_rms.mean)
unnormalized_returns, _ = self.compute_episodic_return(
batch, buffer, indice, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0)
if self._rew_norm:
batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
np.sqrt(self.ret_rms.var + self._eps)
self.ret_rms.update(unnormalized_returns)
else:
batch.returns = unnormalized_returns
return batch
def forward(
self,
@@ -91,7 +98,7 @@ class PGPolicy(BasePolicy):
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.model(batch.obs, state=state, info=batch.info)
logits, h = self.actor(batch.obs, state=state)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
@@ -106,9 +113,10 @@ class PGPolicy(BasePolicy):
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
self.optim.zero_grad()
dist = self(b).dist
a = to_torch_as(b.act, dist.logits)
r = to_torch_as(b.returns, dist.logits)
result = self(b)
dist = result.dist
a = to_torch_as(b.act, result.act)
r = to_torch_as(b.returns, result.act)
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
loss = -(log_prob * r).mean()
loss.backward()
@@ -119,27 +127,3 @@ class PGPolicy(BasePolicy):
self.lr_scheduler.step()
return {"loss": losses}
# def _vanilla_returns(self, batch):
# returns = batch.rew[:]
# last = 0
# for i in range(len(returns) - 1, -1, -1):
# if not batch.done[i]:
# returns[i] += self._gamma * last
# last = returns[i]
# return returns
# def _vectorized_returns(self, batch):
# # according to my tests, it is slower than _vanilla_returns
# # import scipy.signal
# convolve = np.convolve
# # convolve = scipy.signal.convolve
# rew = batch.rew[::-1]
# batch_size = len(rew)
# gammas = self._gamma ** np.arange(batch_size)
# c = convolve(rew, gammas)[:batch_size]
# T = np.where(batch.done[::-1])[0]
# d = np.zeros_like(rew)
# d[T] += c[T] - rew[T]
# d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
# return (c - convolve(d, gammas)[:batch_size])[::-1]
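The commented-out _vanilla_returns / _vectorized_returns variants removed above are subsumed by the shared GAE path: with gae_lambda=1.0 the deltas telescope, so advantage + v_s equals the discounted reward-to-go G_t for finished episodes (a constant baseline such as ret_rms.mean cancels out), while truncated episodes are bootstrapped with ret_rms.mean. A small worked check of that identity under simplified assumptions (zero baseline, a single finished episode):

import numpy as np

gamma, rew = 0.9, np.array([1.0, 1.0, 1.0])  # one finished 3-step episode
end_flag = np.array([0.0, 0.0, 1.0])
v_s_ = np.zeros(3)                           # terminal value masked to zero
v_s = np.roll(v_s_, 1)

# lambda = 1: the GAE recursion sums discounted deltas, and
# advantage + V(s_t) telescopes back to the Monte Carlo return G_t
delta = rew + gamma * v_s_ - v_s
adv, gae = np.zeros(3), 0.0
for t in (2, 1, 0):
    gae = delta[t] + gamma * 1.0 * (1.0 - end_flag[t]) * gae
    adv[t] = gae

g_t = np.array([1 + 0.9 * (1 + 0.9), 1 + 0.9, 1.0])  # [2.71, 1.9, 1.0]
assert np.allclose(adv + v_s, g_t)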


@@ -1,13 +1,13 @@
import torch
import numpy as np
from torch import nn
from typing import Any, Dict, List, Type, Union, Optional
from typing import Any, Dict, List, Type, Optional
from tianshou.policy import PGPolicy
from tianshou.policy import A2CPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
class PPOPolicy(PGPolicy):
class PPOPolicy(A2CPolicy):
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
:param torch.nn.Module actor: the actor network following the rules in
@@ -30,8 +30,8 @@ class PPOPolicy(PGPolicy):
Default to 5.0 (set None if you do not want to use it).
:param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
Default to True.
:param bool reward_normalization: normalize the returns to Normal(0, 1).
Default to True.
:param bool reward_normalization: normalize the returns and advantage to
Normal(0, 1). Default to False.
:param int max_batchsize: the maximum size of the batch when computing GAE,
depends on the size of available memory and the memory cost of the
model; should be as large as possible within the memory constraint.
@@ -58,7 +58,6 @@ class PPOPolicy(PGPolicy):
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Type[torch.distributions.Distribution],
discount_factor: float = 0.99,
max_grad_norm: Optional[float] = None,
eps_clip: float = 0.2,
vf_coef: float = 0.5,
@@ -66,81 +65,50 @@ class PPOPolicy(PGPolicy):
gae_lambda: float = 0.95,
dual_clip: Optional[float] = None,
value_clip: bool = True,
reward_normalization: bool = True,
max_batchsize: int = 256,
**kwargs: Any,
) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
super().__init__(
actor, critic, optim, dist_fn, max_grad_norm=max_grad_norm,
vf_coef=vf_coef, ent_coef=ent_coef, gae_lambda=gae_lambda,
max_batchsize=max_batchsize, **kwargs)
self._eps_clip = eps_clip
self._weight_vf = vf_coef
self._weight_ent = ent_coef
self.actor = actor
self.critic = critic
self._batch = max_batchsize
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
self._lambda = gae_lambda
assert dual_clip is None or dual_clip > 1.0, \
"Dual-clip PPO parameter should greater than 1.0."
self._dual_clip = dual_clip
self._value_clip = value_clip
self._rew_norm = reward_normalization
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._rew_norm:
mean, std = batch.rew.mean(), batch.rew.std()
if not np.isclose(std, 0.0, 1e-2):
batch.rew = (batch.rew - mean) / std
v, v_, old_log_prob = [], [], []
v_s, v_s_, old_log_prob = [], [], []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
v_.append(self.critic(b.obs_next))
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v[0])))
v_ = to_numpy(torch.cat(v_, dim=0))
batch = self.compute_episodic_return(
batch, buffer, indice, v_, gamma=self._gamma,
gae_lambda=self._lambda, rew_norm=self._rew_norm)
batch.v = torch.cat(v, dim=0).flatten() # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, v[0])
batch.adv = batch.returns - batch.v
v_s.append(self.critic(b.obs))
v_s_.append(self.critic(b.obs_next))
old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
v_s = to_numpy(batch.v_s)
v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
if self._rew_norm: # unnormalize v_s & v_s_
v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
unnormalized_returns, advantages = self.compute_episodic_return(
batch, buffer, indice, v_s_, v_s,
gamma=self._gamma, gae_lambda=self._lambda)
if self._rew_norm:
mean, std = batch.adv.mean(), batch.adv.std()
if not np.isclose(std.item(), 0.0, 1e-2):
batch.adv = (batch.adv - mean) / std
return batch
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
np.sqrt(self.ret_rms.var + self._eps)
self.ret_rms.update(unnormalized_returns)
mean, std = np.mean(advantages), np.std(advantages)
advantages = (advantages - mean) / std # per-batch norm
else:
dist = self.dist_fn(logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
batch.returns = unnormalized_returns
batch.act = to_torch_as(batch.act, batch.v_s)
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, batch.v_s)
batch.adv = to_torch_as(advantages, batch.v_s)
return batch
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
@@ -162,7 +130,8 @@ class PPOPolicy(PGPolicy):
clip_loss = -torch.min(surr1, surr2).mean()
clip_losses.append(clip_loss.item())
if self._value_clip:
v_clip = b.v + (value - b.v).clamp(-self._eps_clip, self._eps_clip)
v_clip = b.v_s + (value - b.v_s).clamp(
-self._eps_clip, self._eps_clip)
vf1 = (b.returns - value).pow(2)
vf2 = (b.returns - v_clip).pow(2)
vf_loss = 0.5 * torch.max(vf1, vf2).mean()
@@ -176,10 +145,10 @@ class PPOPolicy(PGPolicy):
losses.append(loss.item())
self.optim.zero_grad()
loss.backward()
if self._max_grad_norm:
if self._grad_norm is not None:
nn.utils.clip_grad_norm_(
list(self.actor.parameters()) + list(self.critic.parameters()),
self._max_grad_norm)
self._grad_norm)
self.optim.step()
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
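To close the loop on the renamed fields (v becomes v_s, adv is now prepared in process_fn, _max_grad_norm becomes the inherited _grad_norm), here is a sketch of the per-minibatch PPO loss the last hunks modify. It mirrors the clipped surrogate and value clip shown above (dual clip omitted) and is an illustration rather than a drop-in replacement for PPOPolicy.learn:

import torch

def ppo_minibatch_loss(policy, b):
    # b is one minibatch carrying act, obs, adv, v_s, logp_old and returns
    dist = policy(b).dist
    value = policy.critic(b.obs).flatten()
    ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
    surr1 = ratio * b.adv
    surr2 = ratio.clamp(1.0 - policy._eps_clip, 1.0 + policy._eps_clip) * b.adv
    clip_loss = -torch.min(surr1, surr2).mean()
    if policy._value_clip:
        # clip the new value estimate around the old one stored in b.v_s
        v_clip = b.v_s + (value - b.v_s).clamp(-policy._eps_clip, policy._eps_clip)
        vf_loss = 0.5 * torch.max(
            (b.returns - value).pow(2), (b.returns - v_clip).pow(2)).mean()
    else:
        vf_loss = 0.5 * (b.returns - value).pow(2).mean()
    ent_loss = dist.entropy().mean()
    return clip_loss + policy._weight_vf * vf_loss - policy._weight_ent * ent_loss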