diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index ac5b3e0..abe707d 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -5,7 +5,7 @@ import numpy as np import torch from torch.distributions import Independent, Normal -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy @@ -121,15 +121,9 @@ class SACPolicy(DDPGPolicy): # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. - if self.action_scaling and self.action_space is not None: - low, high = self.action_space.low, self.action_space.high # type: ignore - action_scale = to_torch_as((high - low) / 2.0, act) - else: - action_scale = 1.0 # type: ignore squashed_action = torch.tanh(act) - log_prob = log_prob - torch.log( - action_scale * (1 - squashed_action.pow(2)) + self.__eps - ).sum(-1, keepdim=True) + log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + + self.__eps).sum(-1, keepdim=True) return Batch( logits=logits, act=squashed_action,