Fix action scaling bug in SAC (#591)

close #588
This commit is contained in:
ChenDRAG 2022-04-12 00:26:06 +08:00 committed by GitHub
parent f13e415eb0
commit 75d7c9f1d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from torch.distributions import Independent, Normal
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.data import Batch, ReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
@ -121,15 +121,9 @@ class SACPolicy(DDPGPolicy):
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
if self.action_scaling and self.action_space is not None:
low, high = self.action_space.low, self.action_space.high # type: ignore
action_scale = to_torch_as((high - low) / 2.0, act)
else:
action_scale = 1.0 # type: ignore
squashed_action = torch.tanh(act)
log_prob = log_prob - torch.log(
action_scale * (1 - squashed_action.pow(2)) + self.__eps
).sum(-1, keepdim=True)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
self.__eps).sum(-1, keepdim=True)
return Batch(
logits=logits,
act=squashed_action,