diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index d701aca..3a55fba 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -87,6 +87,7 @@ class SACPolicy(DDPGPolicy): act = y * self._action_scale + self._action_bias log_prob = dist.log_prob(x) - torch.log( self._action_scale * (1 - y.pow(2)) + self.__eps) + log_prob = torch.unsqueeze(torch.sum(log_prob, 1), 1) act = act.clamp(self._range[0], self._range[1]) return Batch( logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)