Fix log_prob and PPO dual_clip (#49)
* Added DiagGaussian to fix log_probg * Disable PPO dual_clip
This commit is contained in:
parent
70122dc03d
commit
57bca16f94
@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
|
|
||||||
from tianshou.env import VectorEnv
|
from tianshou.env import VectorEnv
|
||||||
from tianshou.policy import PPOPolicy
|
from tianshou.policy import PPOPolicy
|
||||||
|
from tianshou.policy.utils import DiagGaussian
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
|
|
||||||
@ -44,7 +45,7 @@ def get_args():
|
|||||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||||
parser.add_argument('--rew-norm', type=bool, default=True)
|
parser.add_argument('--rew-norm', type=bool, default=True)
|
||||||
parser.add_argument('--dual-clip', type=float, default=5.)
|
# parser.add_argument('--dual-clip', type=float, default=5.)
|
||||||
parser.add_argument('--value-clip', type=bool, default=True)
|
parser.add_argument('--value-clip', type=bool, default=True)
|
||||||
args = parser.parse_known_args()[0]
|
args = parser.parse_known_args()[0]
|
||||||
return args
|
return args
|
||||||
@ -85,7 +86,7 @@ def test_ppo(args=get_args()):
|
|||||||
torch.nn.init.zeros_(m.bias)
|
torch.nn.init.zeros_(m.bias)
|
||||||
optim = torch.optim.Adam(list(
|
optim = torch.optim.Adam(list(
|
||||||
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||||
dist = torch.distributions.Normal
|
dist = DiagGaussian
|
||||||
policy = PPOPolicy(
|
policy = PPOPolicy(
|
||||||
actor, critic, optim, dist, args.gamma,
|
actor, critic, optim, dist, args.gamma,
|
||||||
max_grad_norm=args.max_grad_norm,
|
max_grad_norm=args.max_grad_norm,
|
||||||
@ -93,7 +94,8 @@ def test_ppo(args=get_args()):
|
|||||||
vf_coef=args.vf_coef,
|
vf_coef=args.vf_coef,
|
||||||
ent_coef=args.ent_coef,
|
ent_coef=args.ent_coef,
|
||||||
reward_normalization=args.rew_norm,
|
reward_normalization=args.rew_norm,
|
||||||
dual_clip=args.dual_clip,
|
# dual_clip=args.dual_clip,
|
||||||
|
# dual clip cause monotonically increasing log_std :)
|
||||||
value_clip=args.value_clip,
|
value_clip=args.value_clip,
|
||||||
# action_range=[env.action_space.low[0], env.action_space.high[0]],)
|
# action_range=[env.action_space.low[0], env.action_space.high[0]],)
|
||||||
# if clip the action, ppo would not converge :)
|
# if clip the action, ppo would not converge :)
|
||||||
|
@ -53,7 +53,7 @@ class PPOPolicy(PGPolicy):
|
|||||||
ent_coef: float = .01,
|
ent_coef: float = .01,
|
||||||
action_range: Optional[Tuple[float, float]] = None,
|
action_range: Optional[Tuple[float, float]] = None,
|
||||||
gae_lambda: float = 0.95,
|
gae_lambda: float = 0.95,
|
||||||
dual_clip: float = 5.,
|
dual_clip: float = None,
|
||||||
value_clip: bool = True,
|
value_clip: bool = True,
|
||||||
reward_normalization: bool = True,
|
reward_normalization: bool = True,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
|
@ -6,6 +6,7 @@ from typing import Dict, Tuple, Union, Optional
|
|||||||
|
|
||||||
from tianshou.data import Batch
|
from tianshou.data import Batch
|
||||||
from tianshou.policy import DDPGPolicy
|
from tianshou.policy import DDPGPolicy
|
||||||
|
from tianshou.policy.utils import DiagGaussian
|
||||||
|
|
||||||
|
|
||||||
class SACPolicy(DDPGPolicy):
|
class SACPolicy(DDPGPolicy):
|
||||||
@ -94,13 +95,12 @@ class SACPolicy(DDPGPolicy):
|
|||||||
obs = getattr(batch, input)
|
obs = getattr(batch, input)
|
||||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||||
assert isinstance(logits, tuple)
|
assert isinstance(logits, tuple)
|
||||||
dist = torch.distributions.Normal(*logits)
|
dist = DiagGaussian(*logits)
|
||||||
x = dist.rsample()
|
x = dist.rsample()
|
||||||
y = torch.tanh(x)
|
y = torch.tanh(x)
|
||||||
act = y * self._action_scale + self._action_bias
|
act = y * self._action_scale + self._action_bias
|
||||||
log_prob = dist.log_prob(x) - torch.log(
|
log_prob = dist.log_prob(x) - torch.log(
|
||||||
self._action_scale * (1 - y.pow(2)) + self.__eps)
|
self._action_scale * (1 - y.pow(2)) + self.__eps)
|
||||||
log_prob = torch.unsqueeze(torch.sum(log_prob, 1), 1)
|
|
||||||
act = act.clamp(self._range[0], self._range[1])
|
act = act.clamp(self._range[0], self._range[1])
|
||||||
return Batch(
|
return Batch(
|
||||||
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
|
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
|
||||||
|
13
tianshou/policy/utils.py
Normal file
13
tianshou/policy/utils.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class DiagGaussian(torch.distributions.Normal):
|
||||||
|
"""Diagonal Gaussian Distribution
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def log_prob(self, actions):
|
||||||
|
return super().log_prob(actions).sum(-1, keepdim=True)
|
||||||
|
|
||||||
|
def entropy(self):
|
||||||
|
return super().entropy().sum(-1)
|
Loading…
x
Reference in New Issue
Block a user