Fix log_prob and PPO dual_clip (#49)
* Added DiagGaussian to fix log_probg * Disable PPO dual_clip
This commit is contained in:
parent
70122dc03d
commit
57bca16f94
@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.policy.utils import DiagGaussian
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
@ -44,7 +45,7 @@ def get_args():
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--rew-norm', type=bool, default=True)
|
||||
parser.add_argument('--dual-clip', type=float, default=5.)
|
||||
# parser.add_argument('--dual-clip', type=float, default=5.)
|
||||
parser.add_argument('--value-clip', type=bool, default=True)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
@ -85,7 +86,7 @@ def test_ppo(args=get_args()):
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(list(
|
||||
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Normal
|
||||
dist = DiagGaussian
|
||||
policy = PPOPolicy(
|
||||
actor, critic, optim, dist, args.gamma,
|
||||
max_grad_norm=args.max_grad_norm,
|
||||
@ -93,7 +94,8 @@ def test_ppo(args=get_args()):
|
||||
vf_coef=args.vf_coef,
|
||||
ent_coef=args.ent_coef,
|
||||
reward_normalization=args.rew_norm,
|
||||
dual_clip=args.dual_clip,
|
||||
# dual_clip=args.dual_clip,
|
||||
# dual clip cause monotonically increasing log_std :)
|
||||
value_clip=args.value_clip,
|
||||
# action_range=[env.action_space.low[0], env.action_space.high[0]],)
|
||||
# if clip the action, ppo would not converge :)
|
||||
|
@ -53,7 +53,7 @@ class PPOPolicy(PGPolicy):
|
||||
ent_coef: float = .01,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: float = 5.,
|
||||
dual_clip: float = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs) -> None:
|
||||
|
@ -6,6 +6,7 @@ from typing import Dict, Tuple, Union, Optional
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.policy.utils import DiagGaussian
|
||||
|
||||
|
||||
class SACPolicy(DDPGPolicy):
|
||||
@ -94,13 +95,12 @@ class SACPolicy(DDPGPolicy):
|
||||
obs = getattr(batch, input)
|
||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||
assert isinstance(logits, tuple)
|
||||
dist = torch.distributions.Normal(*logits)
|
||||
dist = DiagGaussian(*logits)
|
||||
x = dist.rsample()
|
||||
y = torch.tanh(x)
|
||||
act = y * self._action_scale + self._action_bias
|
||||
log_prob = dist.log_prob(x) - torch.log(
|
||||
self._action_scale * (1 - y.pow(2)) + self.__eps)
|
||||
log_prob = torch.unsqueeze(torch.sum(log_prob, 1), 1)
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(
|
||||
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
|
||||
|
13
tianshou/policy/utils.py
Normal file
13
tianshou/policy/utils.py
Normal file
@ -0,0 +1,13 @@
|
||||
import torch
|
||||
|
||||
|
||||
class DiagGaussian(torch.distributions.Normal):
|
||||
"""Diagonal Gaussian Distribution
|
||||
|
||||
"""
|
||||
|
||||
def log_prob(self, actions):
|
||||
return super().log_prob(actions).sum(-1, keepdim=True)
|
||||
|
||||
def entropy(self):
|
||||
return super().entropy().sum(-1)
|
Loading…
x
Reference in New Issue
Block a user