Fix action scaling bug in SAC (#591)

close #588
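
Background: with action_scaling enabled, the rescaling of actions from [-1, 1] to the environment's bounds happens outside the policy (in BasePolicy.map_action), so the log-prob computed in SACPolicy.forward should describe the squashed action itself. Multiplying the Jacobian term by action_scale effectively shifted every log-prob by a constant sum over action dimensions of log(action_scale), which biases the entropy term in the actor and alpha losses. The tanh correction from the SAC paper (arXiv 1801.01290, Eq. 21 in appendix C) has no scale factor inside the logarithm:

    \log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i=1}^{D} \log\left(1 - \tanh^2(u_i)\right), \qquad a = \tanh(u)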
ChenDRAG 2022-04-12 00:26:06 +08:00 committed by GitHub
parent f13e415eb0
commit 75d7c9f1d9


@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from torch.distributions import Independent, Normal
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.data import Batch, ReplayBuffer
 from tianshou.exploration import BaseNoise
 from tianshou.policy import DDPGPolicy
@@ -121,15 +121,9 @@ class SACPolicy(DDPGPolicy):
         # apply correction for Tanh squashing when computing logprob from Gaussian
         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
         # in appendix C to get some understanding of this equation.
-        if self.action_scaling and self.action_space is not None:
-            low, high = self.action_space.low, self.action_space.high  # type: ignore
-            action_scale = to_torch_as((high - low) / 2.0, act)
-        else:
-            action_scale = 1.0  # type: ignore
         squashed_action = torch.tanh(act)
-        log_prob = log_prob - torch.log(
-            action_scale * (1 - squashed_action.pow(2)) + self.__eps
-        ).sum(-1, keepdim=True)
+        log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
+                                        self.__eps).sum(-1, keepdim=True)
         return Batch(
             logits=logits,
             act=squashed_action,