Fix log_prob and PPO dual_clip (#49)

* Added DiagGaussian to fix log_prob

* Disabled PPO dual_clip
Imone authored 2020-05-18 16:23:35 +08:00 (committed by GitHub)
parent 70122dc03d, commit 57bca16f94
4 changed files with 21 additions and 6 deletions
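
Background for the first bullet: torch.distributions.Normal treats each action dimension independently, so its log_prob is per-dimension, while PPO's importance weight exp(logp_new - logp_old) should be a single scalar per sample, i.e. the per-dimension log-probs must be summed first. A minimal sketch of the pitfall (illustrative only, not part of the diff):

import torch

old = torch.distributions.Normal(torch.zeros(4, 3), torch.ones(4, 3))
new = torch.distributions.Normal(torch.full((4, 3), 0.1), torch.ones(4, 3))
act = old.sample()

# per-dimension log_prob -> a (4, 3) "ratio", not one weight per sample
bad_ratio = (new.log_prob(act) - old.log_prob(act)).exp()
# summing over the action dimension first gives the intended (4, 1) weight,
# which is what the new DiagGaussian.log_prob produces directly
good_ratio = (new.log_prob(act).sum(-1, keepdim=True)
              - old.log_prob(act).sum(-1, keepdim=True)).exp()
print(bad_ratio.shape, good_ratio.shape)  # torch.Size([4, 3]) torch.Size([4, 1])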


@@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.env import VectorEnv
 from tianshou.policy import PPOPolicy
+from tianshou.policy.utils import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
@@ -44,7 +45,7 @@ def get_args():
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.95)
     parser.add_argument('--rew-norm', type=bool, default=True)
-    parser.add_argument('--dual-clip', type=float, default=5.)
+    # parser.add_argument('--dual-clip', type=float, default=5.)
     parser.add_argument('--value-clip', type=bool, default=True)
     args = parser.parse_known_args()[0]
     return args
@@ -85,7 +86,7 @@ def test_ppo(args=get_args()):
             torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
-    dist = torch.distributions.Normal
+    dist = DiagGaussian
     policy = PPOPolicy(
         actor, critic, optim, dist, args.gamma,
         max_grad_norm=args.max_grad_norm,
@@ -93,7 +94,8 @@ def test_ppo(args=get_args()):
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
         reward_normalization=args.rew_norm,
-        dual_clip=args.dual_clip,
+        # dual_clip=args.dual_clip,
+        # dual clip causes a monotonically increasing log_std :)
         value_clip=args.value_clip,
         # action_range=[env.action_space.low[0], env.action_space.high[0]],)
         # if the action is clipped, ppo would not converge :)
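
For reference, the dual_clip option disabled above refers to the dual-clipped PPO surrogate, which additionally bounds the objective for negative advantages. A rough sketch of that objective (illustrative; the function name and shapes are assumptions, not tianshou's exact implementation):

import torch

def dual_clip_surrogate(ratio, adv, eps=0.2, dual_clip=5.0):
    # standard PPO clipped surrogate
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps, 1.0 + eps) * adv
    clipped = torch.min(surr1, surr2)
    # dual clip: for negative advantages, also bound the objective from below
    # by dual_clip * adv so that very large ratios cannot dominate the update
    return torch.where(adv < 0, torch.max(clipped, dual_clip * adv), clipped)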


@@ -53,7 +53,7 @@ class PPOPolicy(PGPolicy):
                 ent_coef: float = .01,
                 action_range: Optional[Tuple[float, float]] = None,
                 gae_lambda: float = 0.95,
-                 dual_clip: float = 5.,
+                 dual_clip: float = None,
                 value_clip: bool = True,
                 reward_normalization: bool = True,
                 **kwargs) -> None:


@@ -6,6 +6,7 @@ from typing import Dict, Tuple, Union, Optional
 from tianshou.data import Batch
 from tianshou.policy import DDPGPolicy
+from tianshou.policy.utils import DiagGaussian


 class SACPolicy(DDPGPolicy):
@@ -94,13 +95,12 @@ class SACPolicy(DDPGPolicy):
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
-        dist = torch.distributions.Normal(*logits)
+        dist = DiagGaussian(*logits)
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
         log_prob = dist.log_prob(x) - torch.log(
             self._action_scale * (1 - y.pow(2)) + self.__eps)
-        log_prob = torch.unsqueeze(torch.sum(log_prob, 1), 1)
         act = act.clamp(self._range[0], self._range[1])
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
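
With DiagGaussian, dist.log_prob(x) already returns the per-sample joint log-probability with a trailing dimension of 1, which is why the explicit torch.sum/torch.unsqueeze line is dropped. The subtracted torch.log term is the change-of-variables correction for the tanh squashing; a self-contained sketch of that computation (variable names mirror the hunk, but this is an illustration, not the policy's code):

import torch

def squashed_gaussian_sample(mu, sigma, action_scale, action_bias, eps=1e-8):
    # sample a tanh-squashed action and its log-probability
    dist = torch.distributions.Normal(mu, sigma)
    x = dist.rsample()                       # pre-squash sample
    y = torch.tanh(x)
    act = y * action_scale + action_bias
    # joint log-prob of the diagonal Gaussian, summed over action dimensions
    log_prob = dist.log_prob(x).sum(-1, keepdim=True)
    # Jacobian of act = tanh(x) * scale + bias: d act_i / d x_i = scale * (1 - y_i^2)
    log_prob = log_prob - torch.log(
        action_scale * (1 - y.pow(2)) + eps).sum(-1, keepdim=True)
    return act, log_prob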

tianshou/policy/utils.py (new file, +13 lines)

@@ -0,0 +1,13 @@
+import torch
+
+
+class DiagGaussian(torch.distributions.Normal):
+    """Diagonal Gaussian Distribution
+    """
+
+    def log_prob(self, actions):
+        return super().log_prob(actions).sum(-1, keepdim=True)
+
+    def entropy(self):
+        return super().entropy().sum(-1)
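
A quick shape check of the new distribution (illustrative usage, assuming a batch of 4 and a 3-dimensional action):

import torch
from tianshou.policy.utils import DiagGaussian

mu, sigma = torch.zeros(4, 3), torch.ones(4, 3)
dist = DiagGaussian(mu, sigma)
a = dist.sample()
print(dist.log_prob(a).shape)  # torch.Size([4, 1]): joint log-prob per sample
print(dist.entropy().shape)    # torch.Size([4]): summed over action dimensions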