fix ppo
This commit is contained in:
parent 680fc0ffbe
commit 6bf1ea644d
@@ -88,7 +88,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
 | PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | ? |
 | DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | ? |
 | A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | ? |
-| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
+| PPO - CartPole | 32.55±10.09s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
 | DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
 | TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
 | SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
@@ -62,7 +62,7 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
 * :class:`~tianshou.policy.PPOPolicy`: TBD
 * :class:`~tianshou.policy.DDPGPolicy`, :class:`~tianshou.policy.TD3Policy`, and :class:`~tianshou.policy.SACPolicy`: We found two tricks. The first is to ignore the done flag. The second is to normalize the reward to a standard normal distribution (this goes against the theoretical analysis, but it works very well in practice). The two tricks work amazingly well on MuJoCo tasks, typically with much faster convergence (1M -> 200K steps).
 
-* On-policy algorithms: increase the repeat-time (to 2 or 4) of the given batch in each training update will make the algorithm more stable.
+* On-policy algorithms: increasing the repeat-time of the given batch in each training update (to 2 or 4 for trivial benchmarks, 10 for MuJoCo) will make the algorithm more stable.
 
 
 Code-level optimization
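The reward-normalization trick described in the hunk above amounts to standardizing the sampled rewards before they enter the loss. A minimal sketch of the idea, assuming a NumPy array of rewards (the helper name `normalize_reward` is illustrative, not part of the Tianshou API):

```python
import numpy as np

def normalize_reward(rew, eps=np.finfo(np.float32).eps.item()):
    """Rescale a batch of rewards to roughly zero mean and unit variance."""
    mean, std = rew.mean(), rew.std()
    if std > eps:  # skip near-constant rewards to avoid dividing by ~0
        return (rew - mean) / std
    return rew

# e.g. Pendulum-style rewards in [-16, 0] become centred with std ~1
print(normalize_reward(np.array([-16.0, -8.0, -1.0, 0.0])))
```

The same mean/std-with-epsilon guard appears in the `PPOPolicy.process_fn` and `learn` changes later in this commit.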
@@ -39,7 +39,8 @@ class ActorProb(nn.Module):
         self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
         self.model = nn.Sequential(*self.model)
         self.mu = nn.Linear(128, np.prod(action_shape))
-        self.sigma = nn.Linear(128, np.prod(action_shape))
+        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
+        # self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
 
     def forward(self, s, **kwargs):
@@ -48,8 +49,13 @@ class ActorProb(nn.Module):
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
-        sigma = torch.exp(self.sigma(logits))
+        mu = self.mu(logits)
+        shape = [1] * len(mu.shape)
+        shape[1] = -1
+        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
+        # assert sigma.shape == mu.shape
+        # mu = self._max * torch.tanh(self.mu(logits))
+        # sigma = torch.exp(self.sigma(logits))
         return (mu, sigma), None
 
 
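The `ActorProb` hunks above replace the state-dependent `sigma` head with a single learned parameter that is broadcast to the shape of `mu`. A standalone sketch of that broadcasting pattern (a hypothetical module, not the test network itself):

```python
import torch
import torch.nn as nn

class GaussianHead(nn.Module):
    """mu is state-dependent; log-sigma is one free parameter shared by all states."""

    def __init__(self, hidden_dim, action_dim):
        super().__init__()
        self.mu = nn.Linear(hidden_dim, action_dim)
        self.log_sigma = nn.Parameter(torch.zeros(action_dim, 1))

    def forward(self, logits):
        mu = self.mu(logits)            # (batch, action_dim)
        shape = [1] * len(mu.shape)
        shape[1] = -1                   # view log_sigma as (1, action_dim)
        sigma = (self.log_sigma.view(shape) + torch.zeros_like(mu)).exp()
        return mu, sigma                # sigma broadcasts to mu's shape

head = GaussianHead(128, 1)
mu, sigma = head(torch.randn(4, 128))
assert mu.shape == sigma.shape == (4, 1)
```

Making the standard deviation independent of the observation is a common choice for PPO on continuous control; the continuous PPO test below likewise drops `action_range` (see the commented-out line and its note).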
@@ -36,6 +36,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -77,7 +78,7 @@ def test_ddpg(args=get_args()):
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -20,7 +20,7 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
@@ -28,9 +28,9 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
-    parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=2)
-    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -39,16 +39,18 @@ def get_args():
         default='cuda' if torch.cuda.is_available() else 'cpu')
     # ppo special
     parser.add_argument('--vf-coef', type=float, default=0.5)
-    parser.add_argument('--ent-coef', type=float, default=0.0)
+    parser.add_argument('--ent-coef', type=float, default=0.01)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--dual-clip', type=float, default=5.)
+    parser.add_argument('--value-clip', type=bool, default=True)
     args = parser.parse_known_args()[0]
     return args
 
 
-def _test_ppo(args=get_args()):
-    # just a demo, I have not made it work :(
+def test_ppo(args=get_args()):
     torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
@@ -85,7 +87,11 @@ def _test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=args.rew_norm,
+        dual_clip=args.dual_clip,
+        value_clip=args.value_clip,
+        # action_range=[env.action_space.low[0], env.action_space.high[0]],)
+        # if clip the action, ppo would not converge :)
         gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
@@ -121,4 +127,4 @@ def _test_ppo(args=get_args()):
 
 
 if __name__ == '__main__':
-    _test_ppo()
+    test_ppo()
@@ -37,6 +37,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -82,7 +83,7 @@ def test_sac_with_il(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -39,6 +39,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -85,7 +86,7 @@ def test_td3(args=get_args()):
         args.tau, args.gamma, args.exploration_noise, args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -57,6 +57,18 @@ def test_fn(size=2560):
     batch = fn(batch, buf, 0)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
     assert abs(batch.returns - ans).sum() <= 1e-5
+    batch = Batch(
+        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
+        rew=np.array([
+            101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202])
+    )
+    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
+    ret = policy.compute_episodic_return(batch, v, gamma=0.99, gae_lambda=0.95)
+    returns = np.array([
+        454.8344, 376.1143, 291.298, 200.,
+        464.5610, 383.1085, 295.387, 201.,
+        474.2876, 390.1027, 299.476, 202.])
+    assert abs(ret.returns - returns).sum() <= 1e-3
 if __name__ == '__main__':
     batch = Batch(
         done=np.random.randint(100, size=size) == 0,
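The expected values in the new assertion can be reproduced by hand. Taking only the last episode (indices 8-11) and treating `v` as V(s_{t+1}), a quick check in plain NumPy (illustrative only):

```python
import numpy as np

gamma, lam = 0.99, 0.95
rew = np.array([107., 108., 109., 202.])
v_next = np.array([8., 9., 10., -3.])   # V(s_{t+1}); the last entry is masked by done
v_s = np.array([-2., 8., 9., 10.])      # V(s_t): v_next rolled by one over the full batch
done = np.array([0., 0., 0., 1.])

gae, returns = 0., np.zeros(4)
for i in reversed(range(4)):
    delta = rew[i] + gamma * (1. - done[i]) * v_next[i] - v_s[i]
    gae = delta + gamma * lam * (1. - done[i]) * gae
    returns[i] = v_s[i] + gae           # lambda-return = V(s_t) + GAE advantage
print(returns)                          # -> [474.2876, 390.1027, 299.476, 202.]
```

The -2 in `v_s` comes from the previous episode's last `v` entry, which is exactly what the `np.roll` in `compute_episodic_return` produces.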
@@ -20,17 +20,17 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=2000)
+    parser.add_argument('--collect-per-step', type=int, default=20)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=32)
+    parser.add_argument('--training-num', type=int, default=20)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -42,7 +42,10 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
-    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--gae-lambda', type=float, default=1)
+    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=bool, default=True)
     args = parser.parse_known_args()[0]
     return args
 
@@ -78,7 +81,10 @@ def test_ppo(args=get_args()):
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
         action_range=None,
-        gae_lambda=args.gae_lambda)
+        gae_lambda=args.gae_lambda,
+        reward_normalization=args.rew_norm,
+        dual_clip=args.dual_clip,
+        value_clip=args.value_clip)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -79,7 +79,7 @@ class Batch(object):
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in self.__dict__.keys():
+        for k in sorted(self.__dict__.keys()):
             if k[0] != '_' and self.__dict__[k] is not None:
                 rpl = '\n' + ' ' * (6 + len(k))
                 obj = str(self.__dict__[k]).replace('\n', rpl)
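With `sorted()`, `Batch.__repr__` now prints keys in a deterministic order. A tiny illustration with made-up values:

```python
import numpy as np
from tianshou.data import Batch

b = Batch(rew=np.array([1.0]), act=np.array([0]), done=np.array([False]))
print(b)   # keys are listed alphabetically: act, done, rew
```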
@@ -76,7 +76,8 @@ class BasePolicy(ABC, nn.Module):
         """
         pass
 
-    def compute_episodic_return(self, batch, v_s_=None,
+    @staticmethod
+    def compute_episodic_return(batch, v_s_=None,
                                 gamma=0.99, gae_lambda=0.95):
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimation (arXiv:1506.02438).
@@ -93,11 +94,11 @@ class BasePolicy(ABC, nn.Module):
         """
         if v_s_ is None:
             v_s_ = np.zeros_like(batch.rew)
-        if not isinstance(v_s_, np.ndarray):
-            v_s_ = np.array(v_s_, np.float)
         else:
-            v_s_ = v_s_.flatten()
-        batch.returns = np.roll(v_s_, 1)
+            if not isinstance(v_s_, np.ndarray):
+                v_s_ = np.array(v_s_, np.float)
+            v_s_ = v_s_.reshape(batch.rew.shape)
+        batch.returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
         delta = batch.rew + v_s_ * m - batch.returns
         m *= gae_lambda
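The rewritten block above vectorizes Generalized Advantage Estimation: `v_s_` holds V(s_{t+1}), and `np.roll(v_s_, 1, axis=0)` recovers V(s_t), which works because every episode in the batch ends with `done = 1`. A plain-loop reference of what it computes (an illustrative helper, not part of the library):

```python
import numpy as np

def gae_returns(rew, done, v_next, gamma=0.99, gae_lambda=0.95):
    """lambda-returns = V(s_t) + GAE advantage, computed in one backward pass."""
    v_s = np.roll(v_next, 1)                 # V(s_t) recovered from V(s_{t+1})
    returns = np.zeros_like(rew)
    gae = 0.
    for i in reversed(range(len(rew))):
        delta = rew[i] + gamma * (1. - done[i]) * v_next[i] - v_s[i]
        gae = delta + gamma * gae_lambda * (1. - done[i]) * gae
        returns[i] = v_s[i] + gae
    return returns

# reproduces the expected values of the unit test added earlier in this commit
done = np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1])
rew = np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202])
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
print(gae_returns(rew, done, v))  # ~[454.83, 376.11, 291.30, 200., ...]
```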
@@ -51,7 +51,7 @@ class A2CPolicy(PGPolicy):
                 batch, None, gamma=self._gamma, gae_lambda=self._lambda)
         v_ = []
         with torch.no_grad():
-            for b in batch.split(self._batch * 4, permute=False):
+            for b in batch.split(self._batch, permute=False):
                 v_.append(self.critic(b.obs_next).detach().cpu().numpy())
         v_ = np.concatenate(v_, axis=0)
         return self.compute_episodic_return(
@@ -1,8 +1,6 @@
 import torch
 import numpy as np
 from torch import nn
-from copy import deepcopy
-import torch.nn.functional as F
 
 from tianshou.data import Batch
 from tianshou.policy import PGPolicy
@@ -28,6 +26,13 @@ class PPOPolicy(PGPolicy):
     :type action_range: [float, float]
     :param float gae_lambda: in [0, 1], param for Generalized Advantage
         Estimation, defaults to 0.95.
+    :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
+        where c > 1 is a constant indicating the lower bound,
+        defaults to 5.0 (set ``None`` if you do not want to use it).
+    :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1,
+        defaults to ``True``.
+    :param bool reward_normalization: normalize the returns to Normal(0, 1),
+        defaults to ``True``.
 
     .. seealso::
 
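For reference, the two clipping options documented above can be written as small standalone functions. This sketch mirrors the logic added to `learn()` later in this commit; the function names and arguments are illustrative:

```python
import torch

def dual_clip_policy_loss(ratio, adv, eps_clip=0.2, dual_clip=5.0):
    surr1 = ratio * adv
    surr2 = ratio.clamp(1. - eps_clip, 1. + eps_clip) * adv
    if dual_clip is None:
        return -torch.min(surr1, surr2).mean()
    # arXiv:1912.09729 Eq. 5: additionally lower-bound the surrogate by c * A,
    # which limits the penalty when the advantage is negative and the ratio is large
    return -torch.max(torch.min(surr1, surr2), dual_clip * adv).mean()

def clipped_value_loss(value, value_old, returns, eps_clip=0.2):
    # arXiv:1811.02553 Sec. 4.1: keep the new value prediction near the old one
    v_clip = value_old + (value - value_old).clamp(-eps_clip, eps_clip)
    return .5 * torch.max((returns - value).pow(2), (returns - v_clip).pow(2)).mean()
```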
@@ -36,13 +41,9 @@ class PPOPolicy(PGPolicy):
     """
 
     def __init__(self, actor, critic, optim, dist_fn,
-                 discount_factor=0.99,
-                 max_grad_norm=.5,
-                 eps_clip=.2,
-                 vf_coef=.5,
-                 ent_coef=.0,
-                 action_range=None,
-                 gae_lambda=0.95,
+                 discount_factor=0.99, max_grad_norm=.5, eps_clip=.2,
+                 vf_coef=.5, ent_coef=.01, action_range=None, gae_lambda=0.95,
+                 dual_clip=5., value_clip=True, reward_normalization=True,
                  **kwargs):
         super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
@@ -50,41 +51,36 @@ class PPOPolicy(PGPolicy):
         self._w_vf = vf_coef
         self._w_ent = ent_coef
         self._range = action_range
-        self.actor, self.actor_old = actor, deepcopy(actor)
-        self.actor_old.eval()
-        self.critic, self.critic_old = critic, deepcopy(critic)
-        self.critic_old.eval()
+        self.actor = actor
+        self.critic = critic
         self.optim = optim
         self._batch = 64
         assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
         self._lambda = gae_lambda
-
-    def train(self):
-        """Set the module in training mode, except for the target network."""
-        self.training = True
-        self.actor.train()
-        self.critic.train()
-
-    def eval(self):
-        """Set the module in evaluation mode, except for the target network."""
-        self.training = False
-        self.actor.eval()
-        self.critic.eval()
+        assert dual_clip is None or dual_clip > 1, \
+            'Dual-clip PPO parameter should greater than 1.'
+        self._dual_clip = dual_clip
+        self._value_clip = value_clip
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch, buffer, indice):
+        if self._rew_norm:
+            mean, std = batch.rew.mean(), batch.rew.std()
+            if std > self.__eps:
+                batch.rew = (batch.rew - mean) / std
         if self._lambda in [0, 1]:
             return self.compute_episodic_return(
                 batch, None, gamma=self._gamma, gae_lambda=self._lambda)
         v_ = []
         with torch.no_grad():
-            for b in batch.split(self._batch * 4, permute=False):
-                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
-        v_ = np.concatenate(v_, axis=0)
-        batch.v_ = v_
+            for b in batch.split(self._batch, permute=False):
+                v_.append(self.critic(b.obs_next))
+        v_ = torch.cat(v_, dim=0).cpu().numpy()
         return self.compute_episodic_return(
             batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
 
-    def forward(self, batch, state=None, model='actor', **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
 
         :return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -99,8 +95,7 @@ class PPOPolicy(PGPolicy):
         Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
         more detailed explanation.
         """
-        model = getattr(self, model)
-        logits, h = model(batch.obs, state=state, info=batch.info)
+        logits, h = self.actor(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def sync_weight(self):
|
||||
"""Synchronize the weight for the target network."""
|
||||
self.actor_old.load_state_dict(self.actor.state_dict())
|
||||
self.critic_old.load_state_dict(self.critic.state_dict())
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
||||
r = batch.returns
|
||||
batch.returns = (r - r.mean()) / (r.std() + self._eps)
|
||||
batch.act = torch.tensor(batch.act)
|
||||
batch.returns = torch.tensor(batch.returns)[:, None]
|
||||
batch.v_ = torch.tensor(batch.v_)
|
||||
v = []
|
||||
old_log_prob = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, permute=False):
|
||||
v.append(self.critic(b.obs))
|
||||
old_log_prob.append(self(b).dist.log_prob(
|
||||
torch.tensor(b.act, device=v[0].device)))
|
||||
batch.v = torch.cat(v, dim=0) # old value
|
||||
dev = batch.v.device
|
||||
batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = torch.tensor(
|
||||
batch.returns, dtype=torch.float, device=dev
|
||||
).reshape(batch.v.shape)
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
if std > self.__eps:
|
||||
batch.returns = (batch.returns - mean) / std
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if std > self.__eps:
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
vs_old = self.critic_old(b.obs)
|
||||
vs__old = b.v_.to(vs_old.device)
|
||||
dist = self(b).dist
|
||||
dist_old = self(b, model='actor_old').dist
|
||||
target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
|
||||
adv = (target_v - vs_old).detach()
|
||||
a = b.act.to(adv.device)
|
||||
ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
|
||||
surr1 = ratio * adv
|
||||
value = self.critic(b.obs)
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(
|
||||
1. - self._eps_clip, 1. + self._eps_clip) * adv
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
1. - self._eps_clip, 1. + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2),
|
||||
self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(
|
||||
-self._eps_clip, self._eps_clip)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = .5 * torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = .5 * (b.returns - value).pow(2).mean()
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
@@ -150,7 +164,6 @@ class PPOPolicy(PGPolicy):
                     self.actor.parameters()) + list(self.critic.parameters()),
                     self._max_grad_norm)
                 self.optim.step()
-        self.sync_weight()
         return {
             'loss': losses,
             'loss/clip': clip_losses,
@@ -26,13 +26,13 @@ class MovAvg(object):
     def add(self, x):
         """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
         only one element, a python scalar, or a list of python scalar. It will
-        automatically exclude the infinity.
+        automatically exclude the infinity and NaN.
         """
         if isinstance(x, torch.Tensor):
             x = x.item()
         if isinstance(x, list):
             for _ in x:
-                if _ != np.inf:
+                if _ not in [np.inf, np.nan, -np.inf]:
                     self.cache.append(_)
         elif x != np.inf:
             self.cache.append(x)
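A short usage sketch of the updated filter, assuming the rest of the `MovAvg` interface (a `get()` that averages the cached values) is unchanged:

```python
import torch
from tianshou.utils import MovAvg

stat = MovAvg()
stat.add([1.0, float('inf'), 2.0])   # the infinite entry is dropped
stat.add(torch.tensor([3.0]))        # single-element tensors are unwrapped via .item()
print(stat.get())                    # average of the finite values: 2.0
```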