From 8f718d9b13c2c6dc3bb6382e371461dfa41c5ccd Mon Sep 17 00:00:00 2001
From: nicoguertler <59915837+nicoguertler@users.noreply.github.com>
Date: Tue, 28 Apr 2020 17:44:15 +0200
Subject: [PATCH] Fix log_prob in SAC (#41)

---
 tianshou/policy/modelfree/sac.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index d701aca..3a55fba 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -87,6 +87,7 @@ class SACPolicy(DDPGPolicy):
         act = y * self._action_scale + self._action_bias
         log_prob = dist.log_prob(x) - torch.log(
             self._action_scale * (1 - y.pow(2)) + self.__eps)
+        log_prob = torch.unsqueeze(torch.sum(log_prob, 1), 1)
         act = act.clamp(self._range[0], self._range[1])
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)