fix ppo
This commit is contained in:
parent 680fc0ffbe
commit 6bf1ea644d
@@ -88,7 +88,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
 | PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | ? |
 | DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | ? |
 | A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | ? |
-| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
+| PPO - CartPole | 32.55±10.09s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
 | DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
 | TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
 | SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
@@ -62,7 +62,7 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
 * :class:`~tianshou.policy.PPOPolicy`: TBD
 * :class:`~tianshou.policy.DDPGPolicy`, :class:`~tianshou.policy.TD3Policy`, and :class:`~tianshou.policy.SACPolicy`: We found two tricks. The first is to ignore the done flag. The second is to normalize reward to a standard normal distribution (it is against the theoretical analysis, but indeed works very well). The two tricks work amazingly on Mujoco tasks, typically with a faster converge speed (1M -> 200K).
 
-* On-policy algorithms: increase the repeat-time (to 2 or 4) of the given batch in each training update will make the algorithm more stable.
+* On-policy algorithms: increase the repeat-time (to 2 or 4 for trivial benchmark, 10 for mujoco) of the given batch in each training update will make the algorithm more stable.
 
 
 Code-level optimization
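Note: the reward-normalization trick mentioned in the doc hunk above (and exposed as --rew-norm in the test scripts below) amounts to standardizing the collected rewards or returns. A minimal standalone sketch under that assumption; normalize_returns and its signature are illustrative names, not a Tianshou API:

import numpy as np

def normalize_returns(returns, eps=np.finfo(np.float32).eps.item()):
    """Rescale returns toward Normal(0, 1), skipping degenerate batches."""
    returns = np.asarray(returns, dtype=np.float32)
    mean, std = returns.mean(), returns.std()
    if std > eps:  # avoid dividing by ~0 when all returns are equal
        returns = (returns - mean) / std
    return returns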
@@ -39,7 +39,8 @@ class ActorProb(nn.Module):
         self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
         self.model = nn.Sequential(*self.model)
         self.mu = nn.Linear(128, np.prod(action_shape))
-        self.sigma = nn.Linear(128, np.prod(action_shape))
+        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
+        # self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
 
     def forward(self, s, **kwargs):
@@ -48,8 +49,13 @@ class ActorProb(nn.Module):
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
-        sigma = torch.exp(self.sigma(logits))
+        mu = self.mu(logits)
+        shape = [1] * len(mu.shape)
+        shape[1] = -1
+        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
+        # assert sigma.shape == mu.shape
+        # mu = self._max * torch.tanh(self.mu(logits))
+        # sigma = torch.exp(self.sigma(logits))
        return (mu, sigma), None
 
 
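Note: the two ActorProb hunks above replace the state-dependent sigma head with a single learned parameter that is broadcast over the batch. A minimal sketch of that parameterization, assuming a generic feature extractor; GaussianActor and its arguments are illustrative, not Tianshou classes:

import torch
import torch.nn as nn

class GaussianActor(nn.Module):
    """Gaussian policy head with a state-independent log-std parameter."""

    def __init__(self, feature_dim, action_dim):
        super().__init__()
        self.mu = nn.Linear(feature_dim, action_dim)
        # one log-sigma per action dimension, shared across all states
        self.log_sigma = nn.Parameter(torch.zeros(action_dim))

    def forward(self, feats):
        mu = self.mu(feats)                          # (batch, action_dim)
        sigma = self.log_sigma.exp().expand_as(mu)   # broadcast to batch shape
        return mu, sigma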
@@ -36,6 +36,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -77,7 +78,7 @@ def test_ddpg(args=get_args()):
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -20,7 +20,7 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
@@ -28,9 +28,9 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
-    parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=2)
-    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -39,16 +39,18 @@ def get_args():
         default='cuda' if torch.cuda.is_available() else 'cpu')
     # ppo special
     parser.add_argument('--vf-coef', type=float, default=0.5)
-    parser.add_argument('--ent-coef', type=float, default=0.0)
+    parser.add_argument('--ent-coef', type=float, default=0.01)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--dual-clip', type=float, default=5.)
+    parser.add_argument('--value-clip', type=bool, default=True)
     args = parser.parse_known_args()[0]
     return args
 
 
-def _test_ppo(args=get_args()):
-    # just a demo, I have not made it work :(
+def test_ppo(args=get_args()):
     torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
@@ -85,7 +87,11 @@ def _test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=args.rew_norm,
+        dual_clip=args.dual_clip,
+        value_clip=args.value_clip,
+        # action_range=[env.action_space.low[0], env.action_space.high[0]],)
+        # if clip the action, ppo would not converge :)
         gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
@@ -121,4 +127,4 @@ def _test_ppo(args=get_args()):
 
 
 if __name__ == '__main__':
-    _test_ppo()
+    test_ppo()
@@ -37,6 +37,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
    parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -82,7 +83,7 @@ def test_sac_with_il(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -39,6 +39,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -85,7 +86,7 @@ def test_td3(args=get_args()):
         args.tau, args.gamma, args.exploration_noise, args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -57,6 +57,18 @@ def test_fn(size=2560):
     batch = fn(batch, buf, 0)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
     assert abs(batch.returns - ans).sum() <= 1e-5
+    batch = Batch(
+        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
+        rew=np.array([
+            101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202])
+    )
+    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
+    ret = policy.compute_episodic_return(batch, v, gamma=0.99, gae_lambda=0.95)
+    returns = np.array([
+        454.8344, 376.1143, 291.298, 200.,
+        464.5610, 383.1085, 295.387, 201.,
+        474.2876, 390.1027, 299.476, 202.])
+    assert abs(ret.returns - returns).sum() <= 1e-3
     if __name__ == '__main__':
         batch = Batch(
             done=np.random.randint(100, size=size) == 0,
@@ -20,17 +20,17 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=2000)
+    parser.add_argument('--collect-per-step', type=int, default=20)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=32)
+    parser.add_argument('--training-num', type=int, default=20)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -42,7 +42,10 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
-    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--gae-lambda', type=float, default=1)
+    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=bool, default=True)
     args = parser.parse_known_args()[0]
     return args
 
@@ -78,7 +81,10 @@ def test_ppo(args=get_args()):
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
         action_range=None,
-        gae_lambda=args.gae_lambda)
+        gae_lambda=args.gae_lambda,
+        reward_normalization=args.rew_norm,
+        dual_clip=args.dual_clip,
+        value_clip=args.value_clip)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -79,7 +79,7 @@ class Batch(object):
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in self.__dict__.keys():
+        for k in sorted(self.__dict__.keys()):
             if k[0] != '_' and self.__dict__[k] is not None:
                 rpl = '\n' + ' ' * (6 + len(k))
                 obj = str(self.__dict__[k]).replace('\n', rpl)
@@ -76,7 +76,8 @@ class BasePolicy(ABC, nn.Module):
         """
         pass
 
-    def compute_episodic_return(self, batch, v_s_=None,
+    @staticmethod
+    def compute_episodic_return(batch, v_s_=None,
                                 gamma=0.99, gae_lambda=0.95):
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimation (arXiv:1506.02438).
@@ -93,11 +94,11 @@ class BasePolicy(ABC, nn.Module):
         """
         if v_s_ is None:
             v_s_ = np.zeros_like(batch.rew)
-        if not isinstance(v_s_, np.ndarray):
-            v_s_ = np.array(v_s_, np.float)
         else:
-            v_s_ = v_s_.flatten()
-        batch.returns = np.roll(v_s_, 1)
+            if not isinstance(v_s_, np.ndarray):
+                v_s_ = np.array(v_s_, np.float)
+            v_s_ = v_s_.reshape(batch.rew.shape)
+        batch.returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
         delta = batch.rew + v_s_ * m - batch.returns
         m *= gae_lambda
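Note: for reference, the GAE recursion that compute_episodic_return implements can be written as a standalone NumPy sketch. gae_returns is an illustrative name, and v_next is assumed to hold V(s_{t+1}) for each transition, mirroring the rolled-value trick in the hunk above:

import numpy as np

def gae_returns(rew, done, v_next, gamma=0.99, gae_lambda=0.95):
    """Lambda-returns via the GAE recursion, treating done flags as episode ends."""
    rew, done, v_next = map(np.asarray, (rew, done, v_next))
    v_prev = np.roll(v_next, 1)        # stand-in for V(s_t); first entry is a dummy
    m = (1. - done) * gamma            # zero out bootstrapping at episode ends
    delta = rew + v_next * m - v_prev  # TD residuals
    gae, returns = 0., np.zeros_like(rew, dtype=np.float64)
    for i in range(len(rew) - 1, -1, -1):
        gae = delta[i] + m[i] * gae_lambda * gae
        returns[i] = v_prev[i] + gae
    return returns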
@@ -51,7 +51,7 @@ class A2CPolicy(PGPolicy):
                 batch, None, gamma=self._gamma, gae_lambda=self._lambda)
         v_ = []
         with torch.no_grad():
-            for b in batch.split(self._batch * 4, permute=False):
+            for b in batch.split(self._batch, permute=False):
                 v_.append(self.critic(b.obs_next).detach().cpu().numpy())
         v_ = np.concatenate(v_, axis=0)
         return self.compute_episodic_return(
@@ -1,8 +1,6 @@
 import torch
 import numpy as np
 from torch import nn
-from copy import deepcopy
-import torch.nn.functional as F
 
 from tianshou.data import Batch
 from tianshou.policy import PGPolicy
@@ -28,6 +26,13 @@ class PPOPolicy(PGPolicy):
     :type action_range: [float, float]
     :param float gae_lambda: in [0, 1], param for Generalized Advantage
         Estimation, defaults to 0.95.
+    :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
+        where c > 1 is a constant indicating the lower bound,
+        defaults to 5.0 (set ``None`` if you do not want to use it).
+    :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1,
+        defaults to ``True``.
+    :param bool reward_normalization: normalize the returns to Normal(0, 1),
+        defaults to ``True``.
 
     .. seealso::
 
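Note: the dual-clip and value-clip terms documented above can be summarized in isolation. A hedged sketch of the two loss variants with illustrative names (ppo_losses is not a Tianshou function); ratio, adv, value, old_value and returns are assumed to be aligned tensors:

import torch

def ppo_losses(ratio, adv, value, old_value, returns,
               eps_clip=0.2, dual_clip=5., value_clip=True):
    """Sketch of the dual-clip policy loss (arXiv:1912.09729, Eq. 5)
    and the clipped value loss (arXiv:1811.02553, Sec. 4.1)."""
    surr1 = ratio * adv
    surr2 = ratio.clamp(1. - eps_clip, 1. + eps_clip) * adv
    clipped = torch.min(surr1, surr2)
    if dual_clip is not None:
        # lower-bound the surrogate by c * A for negative advantages
        clip_loss = -torch.max(clipped, dual_clip * adv).mean()
    else:
        clip_loss = -clipped.mean()
    if value_clip:
        v_clip = old_value + (value - old_value).clamp(-eps_clip, eps_clip)
        vf_loss = .5 * torch.max((returns - value).pow(2),
                                 (returns - v_clip).pow(2)).mean()
    else:
        vf_loss = .5 * (returns - value).pow(2).mean()
    return clip_loss, vf_loss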
@@ -36,13 +41,9 @@ class PPOPolicy(PGPolicy):
     """
 
     def __init__(self, actor, critic, optim, dist_fn,
-                 discount_factor=0.99,
-                 max_grad_norm=.5,
-                 eps_clip=.2,
-                 vf_coef=.5,
-                 ent_coef=.0,
-                 action_range=None,
-                 gae_lambda=0.95,
+                 discount_factor=0.99, max_grad_norm=.5, eps_clip=.2,
+                 vf_coef=.5, ent_coef=.01, action_range=None, gae_lambda=0.95,
+                 dual_clip=5., value_clip=True, reward_normalization=True,
                  **kwargs):
         super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
@@ -50,41 +51,36 @@ class PPOPolicy(PGPolicy):
         self._w_vf = vf_coef
         self._w_ent = ent_coef
         self._range = action_range
-        self.actor, self.actor_old = actor, deepcopy(actor)
-        self.actor_old.eval()
-        self.critic, self.critic_old = critic, deepcopy(critic)
-        self.critic_old.eval()
+        self.actor = actor
+        self.critic = critic
         self.optim = optim
         self._batch = 64
         assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
         self._lambda = gae_lambda
-
-    def train(self):
-        """Set the module in training mode, except for the target network."""
-        self.training = True
-        self.actor.train()
-        self.critic.train()
-
-    def eval(self):
-        """Set the module in evaluation mode, except for the target network."""
-        self.training = False
-        self.actor.eval()
-        self.critic.eval()
+        assert dual_clip is None or dual_clip > 1, \
+            'Dual-clip PPO parameter should greater than 1.'
+        self._dual_clip = dual_clip
+        self._value_clip = value_clip
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch, buffer, indice):
+        if self._rew_norm:
+            mean, std = batch.rew.mean(), batch.rew.std()
+            if std > self.__eps:
+                batch.rew = (batch.rew - mean) / std
         if self._lambda in [0, 1]:
             return self.compute_episodic_return(
                 batch, None, gamma=self._gamma, gae_lambda=self._lambda)
         v_ = []
         with torch.no_grad():
-            for b in batch.split(self._batch * 4, permute=False):
-                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
-        v_ = np.concatenate(v_, axis=0)
-        batch.v_ = v_
+            for b in batch.split(self._batch, permute=False):
+                v_.append(self.critic(b.obs_next))
+        v_ = torch.cat(v_, dim=0).cpu().numpy()
         return self.compute_episodic_return(
             batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
 
-    def forward(self, batch, state=None, model='actor', **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
 
         :return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -99,8 +95,7 @@ class PPOPolicy(PGPolicy):
         Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
         more detailed explanation.
         """
-        model = getattr(self, model)
-        logits, h = model(batch.obs, state=state, info=batch.info)
+        logits, h = self.actor(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
@@ -110,35 +105,54 @@ class PPOPolicy(PGPolicy):
             act = act.clamp(self._range[0], self._range[1])
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
-    def sync_weight(self):
-        """Synchronize the weight for the target network."""
-        self.actor_old.load_state_dict(self.actor.state_dict())
-        self.critic_old.load_state_dict(self.critic.state_dict())
-
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         self._batch = batch_size
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        r = batch.returns
-        batch.returns = (r - r.mean()) / (r.std() + self._eps)
-        batch.act = torch.tensor(batch.act)
-        batch.returns = torch.tensor(batch.returns)[:, None]
-        batch.v_ = torch.tensor(batch.v_)
+        v = []
+        old_log_prob = []
+        with torch.no_grad():
+            for b in batch.split(batch_size, permute=False):
+                v.append(self.critic(b.obs))
+                old_log_prob.append(self(b).dist.log_prob(
+                    torch.tensor(b.act, device=v[0].device)))
+        batch.v = torch.cat(v, dim=0)  # old value
+        dev = batch.v.device
+        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+        batch.logp_old = torch.cat(old_log_prob, dim=0)
+        batch.returns = torch.tensor(
+            batch.returns, dtype=torch.float, device=dev
+        ).reshape(batch.v.shape)
+        if self._rew_norm:
+            mean, std = batch.returns.mean(), batch.returns.std()
+            if std > self.__eps:
+                batch.returns = (batch.returns - mean) / std
+        batch.adv = batch.returns - batch.v
+        if self._rew_norm:
+            mean, std = batch.adv.mean(), batch.adv.std()
+            if std > self.__eps:
+                batch.adv = (batch.adv - mean) / std
         for _ in range(repeat):
             for b in batch.split(batch_size):
-                vs_old = self.critic_old(b.obs)
-                vs__old = b.v_.to(vs_old.device)
                 dist = self(b).dist
-                dist_old = self(b, model='actor_old').dist
-                target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
-                adv = (target_v - vs_old).detach()
-                a = b.act.to(adv.device)
-                ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
-                surr1 = ratio * adv
+                value = self.critic(b.obs)
+                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                surr1 = ratio * b.adv
                 surr2 = ratio.clamp(
-                    1. - self._eps_clip, 1. + self._eps_clip) * adv
-                clip_loss = -torch.min(surr1, surr2).mean()
+                    1. - self._eps_clip, 1. + self._eps_clip) * b.adv
+                if self._dual_clip:
+                    clip_loss = -torch.max(torch.min(surr1, surr2),
+                                           self._dual_clip * b.adv).mean()
+                else:
+                    clip_loss = -torch.min(surr1, surr2).mean()
                 clip_losses.append(clip_loss.item())
-                vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
+                if self._value_clip:
+                    v_clip = b.v + (value - b.v).clamp(
+                        -self._eps_clip, self._eps_clip)
+                    vf1 = (b.returns - value).pow(2)
+                    vf2 = (b.returns - v_clip).pow(2)
+                    vf_loss = .5 * torch.max(vf1, vf2).mean()
+                else:
+                    vf_loss = .5 * (b.returns - value).pow(2).mean()
                 vf_losses.append(vf_loss.item())
                 e_loss = dist.entropy().mean()
                 ent_losses.append(e_loss.item())
@@ -150,7 +164,6 @@ class PPOPolicy(PGPolicy):
                     self.actor.parameters()) + list(self.critic.parameters()),
                     self._max_grad_norm)
                 self.optim.step()
-        self.sync_weight()
         return {
             'loss': losses,
             'loss/clip': clip_losses,
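Note: a side effect of the PPO rewrite above is that the policy no longer keeps actor_old/critic_old copies; the probability ratio is formed against log-probabilities recorded once under torch.no_grad() before the repeated minibatch updates. A tiny sketch of that idea (ppo_ratio is an illustrative helper, not part of the library):

import torch

def ppo_ratio(dist, act, logp_old):
    # dist: the current policy's action distribution for a minibatch
    # act, logp_old: actions and log-probs cached before the update epochs
    return (dist.log_prob(act) - logp_old).exp().float()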
@@ -26,13 +26,13 @@ class MovAvg(object):
     def add(self, x):
         """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
         only one element, a python scalar, or a list of python scalar. It will
-        automatically exclude the infinity.
+        automatically exclude the infinity and NaN.
         """
         if isinstance(x, torch.Tensor):
             x = x.item()
         if isinstance(x, list):
             for _ in x:
-                if _ != np.inf:
+                if _ not in [np.inf, np.nan, -np.inf]:
                     self.cache.append(_)
         elif x != np.inf:
             self.cache.append(x)
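Note: the MovAvg change above screens list inputs for infinity and NaN before averaging. A standalone filter in the same spirit, sketched with math.isfinite (add_finite is an illustrative helper, not the class method):

import math

def add_finite(cache, x):
    # append only finite scalars; drops inf, -inf and NaN
    xs = x if isinstance(x, list) else [x]
    for item in xs:
        if math.isfinite(item):
            cache.append(item)
    return cache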