Trinkle23897 2020-04-19 14:30:42 +08:00
parent 680fc0ffbe
commit 6bf1ea644d
14 changed files with 135 additions and 88 deletions

View File

@@ -88,7 +88,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
 | PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | ? |
 | DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | ? |
 | A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | ? |
-| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
+| PPO - CartPole | 32.55±10.09s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
 | DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
 | TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
 | SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |

View File

@@ -62,7 +62,7 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
 * :class:`~tianshou.policy.PPOPolicy`: TBD
 * :class:`~tianshou.policy.DDPGPolicy`, :class:`~tianshou.policy.TD3Policy`, and :class:`~tianshou.policy.SACPolicy`: We found two tricks. The first is to ignore the done flag. The second is to normalize reward to a standard normal distribution (it is against the theoretical analysis, but indeed works very well). The two tricks work amazingly on Mujoco tasks, typically with a faster converge speed (1M -> 200K).
-* On-policy algorithms: increase the repeat-time (to 2 or 4) of the given batch in each training update will make the algorithm more stable.
+* On-policy algorithms: increase the repeat-time (to 2 or 4 for trivial benchmark, 10 for mujoco) of the given batch in each training update will make the algorithm more stable.

 Code-level optimization
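
A note on the reward-normalization trick mentioned in this hunk: it amounts to standardizing the rewards of each sampled batch. A minimal NumPy sketch, assuming a 1-D array of rewards; the helper name and the eps guard are illustrative, not Tianshou API:

import numpy as np

def normalize_rewards(rew, eps=np.finfo(np.float32).eps.item()):
    """Scale a batch of rewards to roughly zero mean and unit variance.

    This mirrors the "normalize reward to a standard normal distribution"
    trick above; the eps guard avoids dividing by zero when all rewards
    in the batch are identical.
    """
    mean, std = rew.mean(), rew.std()
    if std > eps:
        return (rew - mean) / std
    return rew

# usage: rew = normalize_rewards(batch_rew) before computing TD targets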

View File

@@ -39,7 +39,8 @@ class ActorProb(nn.Module):
             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
         self.model = nn.Sequential(*self.model)
         self.mu = nn.Linear(128, np.prod(action_shape))
-        self.sigma = nn.Linear(128, np.prod(action_shape))
+        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
+        # self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action

     def forward(self, s, **kwargs):
@@ -48,8 +49,13 @@ class ActorProb(nn.Module):
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
-        sigma = torch.exp(self.sigma(logits))
+        mu = self.mu(logits)
+        shape = [1] * len(mu.shape)
+        shape[1] = -1
+        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
+        # assert sigma.shape == mu.shape
+        # mu = self._max * torch.tanh(self.mu(logits))
+        # sigma = torch.exp(self.sigma(logits))
         return (mu, sigma), None
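
The change above replaces the state-dependent sigma head with a single learnable log-sigma parameter broadcast over the batch, and stops squashing mu with tanh. A self-contained sketch of that parameterization, with assumed layer sizes and names (not the exact Tianshou module):

import torch
from torch import nn

class GaussianActor(nn.Module):
    """Gaussian policy head: state-dependent mu, state-independent sigma."""

    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.mu = nn.Linear(hidden, action_dim)
        # one log-sigma per action dimension, shared across all states
        self.log_sigma = nn.Parameter(torch.zeros(action_dim))

    def forward(self, s):
        logits = self.model(s)
        mu = self.mu(logits)
        # broadcast the shared log-sigma to the batch shape of mu
        sigma = self.log_sigma.expand_as(mu).exp()
        return mu, sigma

# usage sketch
actor = GaussianActor(state_dim=3, action_dim=1)
mu, sigma = actor(torch.randn(4, 3))
act = torch.distributions.Normal(mu, sigma).sample()

A state-independent log-sigma is a common choice for on-policy Gaussian policies, since exploration noise then does not collapse with the state features.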

View File

@@ -36,6 +36,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -77,7 +78,7 @@ def test_ddpg(args=get_args()):
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
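
One caveat with the new flag, which also applies to the other files below (an observation, not part of the commit): argparse's type=bool converts any non-empty string to True, so --rew-norm False on the command line still evaluates to True. An illustrative workaround with a hypothetical str2bool converter:

import argparse

def str2bool(v):
    # 'False', 'false', '0', 'no' -> False; truthy strings -> True
    return str(v).lower() in ('1', 'true', 'yes')

parser = argparse.ArgumentParser()
parser.add_argument('--rew-norm', type=str2bool, default=True)
print(parser.parse_args(['--rew-norm', 'False']).rew_norm)  # prints False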

View File

@@ -20,7 +20,7 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
@@ -28,9 +28,9 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
-    parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=2)
-    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -39,16 +39,18 @@ def get_args():
         default='cuda' if torch.cuda.is_available() else 'cpu')
     # ppo special
     parser.add_argument('--vf-coef', type=float, default=0.5)
-    parser.add_argument('--ent-coef', type=float, default=0.0)
+    parser.add_argument('--ent-coef', type=float, default=0.01)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--dual-clip', type=float, default=5.)
+    parser.add_argument('--value-clip', type=bool, default=True)
     args = parser.parse_known_args()[0]
     return args

-def _test_ppo(args=get_args()):
-    # just a demo, I have not made it work :(
+def test_ppo(args=get_args()):
     torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
@@ -85,7 +87,11 @@ def _test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=args.rew_norm,
+        dual_clip=args.dual_clip,
+        value_clip=args.value_clip,
+        # action_range=[env.action_space.low[0], env.action_space.high[0]],)
+        # if clip the action, ppo would not converge :)
         gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
@@ -121,4 +127,4 @@ def _test_ppo(args=get_args()):
 if __name__ == '__main__':
-    _test_ppo()
+    test_ppo()

View File

@@ -37,6 +37,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -82,7 +83,7 @@ def test_sac_with_il(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@@ -39,6 +39,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -85,7 +86,7 @@ def test_td3(args=get_args()):
         args.tau, args.gamma, args.exploration_noise, args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@@ -57,6 +57,18 @@ def test_fn(size=2560):
     batch = fn(batch, buf, 0)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
     assert abs(batch.returns - ans).sum() <= 1e-5
+    batch = Batch(
+        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
+        rew=np.array([
+            101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202])
+    )
+    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
+    ret = policy.compute_episodic_return(batch, v, gamma=0.99, gae_lambda=0.95)
+    returns = np.array([
+        454.8344, 376.1143, 291.298, 200.,
+        464.5610, 383.1085, 295.387, 201.,
+        474.2876, 390.1027, 299.476, 202.])
+    assert abs(ret.returns - returns).sum() <= 1e-3

 if __name__ == '__main__':
     batch = Batch(
         done=np.random.randint(100, size=size) == 0,

View File

@@ -20,17 +20,17 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=2000)
+    parser.add_argument('--collect-per-step', type=int, default=20)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=32)
+    parser.add_argument('--training-num', type=int, default=20)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -42,7 +42,10 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
-    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--gae-lambda', type=float, default=1)
+    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=bool, default=True)
     args = parser.parse_known_args()[0]
     return args
@@ -78,7 +81,10 @@ def test_ppo(args=get_args()):
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
         action_range=None,
-        gae_lambda=args.gae_lambda)
+        gae_lambda=args.gae_lambda,
+        reward_normalization=args.rew_norm,
+        dual_clip=args.dual_clip,
+        value_clip=args.value_clip)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@@ -79,7 +79,7 @@ class Batch(object):
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in self.__dict__.keys():
+        for k in sorted(self.__dict__.keys()):
             if k[0] != '_' and self.__dict__[k] is not None:
                 rpl = '\n' + ' ' * (6 + len(k))
                 obj = str(self.__dict__[k]).replace('\n', rpl)

View File

@@ -76,7 +76,8 @@ class BasePolicy(ABC, nn.Module):
         """
         pass

-    def compute_episodic_return(self, batch, v_s_=None,
+    @staticmethod
+    def compute_episodic_return(batch, v_s_=None,
                                 gamma=0.99, gae_lambda=0.95):
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimation (arXiv:1506.02438).
@@ -93,11 +94,11 @@ class BasePolicy(ABC, nn.Module):
         """
         if v_s_ is None:
             v_s_ = np.zeros_like(batch.rew)
-        if not isinstance(v_s_, np.ndarray):
-            v_s_ = np.array(v_s_, np.float)
         else:
-            v_s_ = v_s_.flatten()
-        batch.returns = np.roll(v_s_, 1)
+            if not isinstance(v_s_, np.ndarray):
+                v_s_ = np.array(v_s_, np.float)
+            v_s_ = v_s_.reshape(batch.rew.shape)
+        batch.returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
         delta = batch.rew + v_s_ * m - batch.returns
         m *= gae_lambda
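
The hunk above only shows the head of the GAE computation; the backward recursion over delta sits below the hunk and is unchanged here. For reference, a standalone NumPy sketch of the full lambda-return recurrence (function and argument names are illustrative, not the Tianshou method), which should reproduce the expected returns array added in the test above:

import numpy as np

def gae_return(rew, done, v_next, gamma=0.99, gae_lambda=0.95):
    """Lambda-return via Generalized Advantage Estimation (arXiv:1506.02438).

    rew, done and v_next are 1-D arrays of equal length; v_next[i] is the
    critic's estimate of V(s_{i+1}).  V(s_i) is approximated by rolling
    v_next one step, mirroring the np.roll above.
    """
    v = np.roll(v_next, 1)                  # v[i] ~ V(s_i)
    m = (1. - done) * gamma                 # discount, zeroed at episode ends
    delta = rew + v_next * m - v            # TD residual
    gae, returns = 0., np.zeros_like(rew, dtype=float)
    for i in range(len(rew) - 1, -1, -1):
        gae = delta[i] + m[i] * gae_lambda * gae
        returns[i] = v[i] + gae             # lambda-return = baseline + advantage
    return returns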

View File

@@ -51,7 +51,7 @@ class A2CPolicy(PGPolicy):
                 batch, None, gamma=self._gamma, gae_lambda=self._lambda)
         v_ = []
         with torch.no_grad():
-            for b in batch.split(self._batch * 4, permute=False):
+            for b in batch.split(self._batch, permute=False):
                 v_.append(self.critic(b.obs_next).detach().cpu().numpy())
         v_ = np.concatenate(v_, axis=0)
         return self.compute_episodic_return(

View File

@@ -1,8 +1,6 @@
 import torch
 import numpy as np
 from torch import nn
-from copy import deepcopy
-import torch.nn.functional as F

 from tianshou.data import Batch
 from tianshou.policy import PGPolicy
@@ -28,6 +26,13 @@ class PPOPolicy(PGPolicy):
     :type action_range: [float, float]
     :param float gae_lambda: in [0, 1], param for Generalized Advantage
         Estimation, defaults to 0.95.
+    :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
+        where c > 1 is a constant indicating the lower bound,
+        defaults to 5.0 (set ``None`` if you do not want to use it).
+    :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1,
+        defaults to ``True``.
+    :param bool reward_normalization: normalize the returns to Normal(0, 1),
+        defaults to ``True``.

     .. seealso::
@@ -36,13 +41,9 @@ class PPOPolicy(PGPolicy):
     """

     def __init__(self, actor, critic, optim, dist_fn,
-                 discount_factor=0.99,
-                 max_grad_norm=.5,
-                 eps_clip=.2,
-                 vf_coef=.5,
-                 ent_coef=.0,
-                 action_range=None,
-                 gae_lambda=0.95,
+                 discount_factor=0.99, max_grad_norm=.5, eps_clip=.2,
+                 vf_coef=.5, ent_coef=.01, action_range=None, gae_lambda=0.95,
+                 dual_clip=5., value_clip=True, reward_normalization=True,
                  **kwargs):
         super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
@@ -50,41 +51,36 @@ class PPOPolicy(PGPolicy):
         self._w_vf = vf_coef
         self._w_ent = ent_coef
         self._range = action_range
-        self.actor, self.actor_old = actor, deepcopy(actor)
-        self.actor_old.eval()
-        self.critic, self.critic_old = critic, deepcopy(critic)
-        self.critic_old.eval()
+        self.actor = actor
+        self.critic = critic
         self.optim = optim
         self._batch = 64
         assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
         self._lambda = gae_lambda
-
-    def train(self):
-        """Set the module in training mode, except for the target network."""
-        self.training = True
-        self.actor.train()
-        self.critic.train()
-
-    def eval(self):
-        """Set the module in evaluation mode, except for the target network."""
-        self.training = False
-        self.actor.eval()
-        self.critic.eval()
+        assert dual_clip is None or dual_clip > 1, \
+            'Dual-clip PPO parameter should greater than 1.'
+        self._dual_clip = dual_clip
+        self._value_clip = value_clip
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()

     def process_fn(self, batch, buffer, indice):
+        if self._rew_norm:
+            mean, std = batch.rew.mean(), batch.rew.std()
+            if std > self.__eps:
+                batch.rew = (batch.rew - mean) / std
         if self._lambda in [0, 1]:
             return self.compute_episodic_return(
                 batch, None, gamma=self._gamma, gae_lambda=self._lambda)
         v_ = []
         with torch.no_grad():
-            for b in batch.split(self._batch * 4, permute=False):
-                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
-        v_ = np.concatenate(v_, axis=0)
+            for b in batch.split(self._batch, permute=False):
+                v_.append(self.critic(b.obs_next))
+        v_ = torch.cat(v_, dim=0).cpu().numpy()
+        batch.v_ = v_
         return self.compute_episodic_return(
             batch, v_, gamma=self._gamma, gae_lambda=self._lambda)

-    def forward(self, batch, state=None, model='actor', **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.

         :return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -99,8 +95,7 @@ class PPOPolicy(PGPolicy):
         Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
         more detailed explanation.
         """
-        model = getattr(self, model)
-        logits, h = model(batch.obs, state=state, info=batch.info)
+        logits, h = self.actor(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
@@ -110,35 +105,54 @@ class PPOPolicy(PGPolicy):
             act = act.clamp(self._range[0], self._range[1])
         return Batch(logits=logits, act=act, state=h, dist=dist)

-    def sync_weight(self):
-        """Synchronize the weight for the target network."""
-        self.actor_old.load_state_dict(self.actor.state_dict())
-        self.critic_old.load_state_dict(self.critic.state_dict())
-
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         self._batch = batch_size
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        r = batch.returns
-        batch.returns = (r - r.mean()) / (r.std() + self._eps)
-        batch.act = torch.tensor(batch.act)
-        batch.returns = torch.tensor(batch.returns)[:, None]
-        batch.v_ = torch.tensor(batch.v_)
+        v = []
+        old_log_prob = []
+        with torch.no_grad():
+            for b in batch.split(batch_size, permute=False):
+                v.append(self.critic(b.obs))
+                old_log_prob.append(self(b).dist.log_prob(
+                    torch.tensor(b.act, device=v[0].device)))
+        batch.v = torch.cat(v, dim=0)  # old value
+        dev = batch.v.device
+        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+        batch.logp_old = torch.cat(old_log_prob, dim=0)
+        batch.returns = torch.tensor(
+            batch.returns, dtype=torch.float, device=dev
+        ).reshape(batch.v.shape)
+        if self._rew_norm:
+            mean, std = batch.returns.mean(), batch.returns.std()
+            if std > self.__eps:
+                batch.returns = (batch.returns - mean) / std
+        batch.adv = batch.returns - batch.v
+        if self._rew_norm:
+            mean, std = batch.adv.mean(), batch.adv.std()
+            if std > self.__eps:
+                batch.adv = (batch.adv - mean) / std
         for _ in range(repeat):
             for b in batch.split(batch_size):
-                vs_old = self.critic_old(b.obs)
-                vs__old = b.v_.to(vs_old.device)
                 dist = self(b).dist
-                dist_old = self(b, model='actor_old').dist
-                target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
-                adv = (target_v - vs_old).detach()
-                a = b.act.to(adv.device)
-                ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
-                surr1 = ratio * adv
+                value = self.critic(b.obs)
+                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                surr1 = ratio * b.adv
                 surr2 = ratio.clamp(
-                    1. - self._eps_clip, 1. + self._eps_clip) * adv
-                clip_loss = -torch.min(surr1, surr2).mean()
+                    1. - self._eps_clip, 1. + self._eps_clip) * b.adv
+                if self._dual_clip:
+                    clip_loss = -torch.max(torch.min(surr1, surr2),
+                                           self._dual_clip * b.adv).mean()
+                else:
+                    clip_loss = -torch.min(surr1, surr2).mean()
                 clip_losses.append(clip_loss.item())
-                vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
+                if self._value_clip:
+                    v_clip = b.v + (value - b.v).clamp(
+                        -self._eps_clip, self._eps_clip)
+                    vf1 = (b.returns - value).pow(2)
+                    vf2 = (b.returns - v_clip).pow(2)
+                    vf_loss = .5 * torch.max(vf1, vf2).mean()
+                else:
+                    vf_loss = .5 * (b.returns - value).pow(2).mean()
                 vf_losses.append(vf_loss.item())
                 e_loss = dist.entropy().mean()
                 ent_losses.append(e_loss.item())
@@ -150,7 +164,6 @@ class PPOPolicy(PGPolicy):
                     self.actor.parameters()) + list(self.critic.parameters()),
                     self._max_grad_norm)
                 self.optim.step()
-                self.sync_weight()
         return {
             'loss': losses,
             'loss/clip': clip_losses,
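
To make the new loss terms easier to read outside the diff, here is a compact sketch of the clipped surrogate with optional dual clipping and value clipping, written against plain tensors instead of the policy class (the function and variable names are illustrative; eps_clip, dual_clip and value_clip mirror the new parameters above):

import torch

def ppo_losses(log_prob, logp_old, adv, returns, value, value_old,
               eps_clip=0.2, dual_clip=5., value_clip=True):
    """Clipped PPO policy loss plus an (optionally clipped) value loss."""
    ratio = (log_prob - logp_old).exp()
    surr1 = ratio * adv
    surr2 = ratio.clamp(1. - eps_clip, 1. + eps_clip) * adv
    if dual_clip:
        # arXiv:1912.09729: bound the surrogate from below by c * adv
        clip_loss = -torch.max(torch.min(surr1, surr2), dual_clip * adv).mean()
    else:
        clip_loss = -torch.min(surr1, surr2).mean()
    if value_clip:
        # arXiv:1811.02553 Sec. 4.1: keep the new value close to the old one
        v_clip = value_old + (value - value_old).clamp(-eps_clip, eps_clip)
        vf_loss = .5 * torch.max((returns - value).pow(2),
                                 (returns - v_clip).pow(2)).mean()
    else:
        vf_loss = .5 * (returns - value).pow(2).mean()
    return clip_loss, vf_loss

How the two terms are weighted together with vf_coef and ent_coef happens outside the hunks shown above.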

View File

@@ -26,13 +26,13 @@ class MovAvg(object):
     def add(self, x):
         """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
         only one element, a python scalar, or a list of python scalar. It will
-        automatically exclude the infinity.
+        automatically exclude the infinity and NaN.
         """
         if isinstance(x, torch.Tensor):
             x = x.item()
         if isinstance(x, list):
             for _ in x:
-                if _ != np.inf:
+                if _ not in [np.inf, np.nan, -np.inf]:
                     self.cache.append(_)
         elif x != np.inf:
             self.cache.append(x)
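
A small caveat on the new membership test (an observation, not part of the commit): the "in" check matches np.nan only by object identity, since NaN never compares equal to itself, so a NaN produced elsewhere (e.g. float('nan')) still slips through, and the scalar branch still accepts -inf and NaN. An illustrative stricter filter:

import math
import numpy as np

def is_usable(value):
    """True only for finite numbers; rejects inf, -inf and every NaN."""
    return math.isfinite(value)

assert is_usable(1.5)
assert not is_usable(np.inf) and not is_usable(-np.inf)
assert not is_usable(float('nan'))  # would pass the 'not in' check above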