Fix log_prob in SAC (#41)
This commit is contained in:
parent
69e4b3d301
commit
8f718d9b13
@@ -87,6 +87,7 @@ class SACPolicy(DDPGPolicy):
|
||||
act = y * self._action_scale + self._action_bias
|
||||
log_prob = dist.log_prob(x) - torch.log(
|
||||
self._action_scale * (1 - y.pow(2)) + self.__eps)
|
||||
log_prob = torch.unsqueeze(torch.sum(log_prob, 1), 1)
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(
|
||||
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
|
||||
|
Loading…
x
Reference in New Issue
Block a user