Fix log_prob in SAC (#41)

This commit is contained in:
nicoguertler 2020-04-28 17:44:15 +02:00 committed by GitHub
parent 69e4b3d301
commit 8f718d9b13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -87,6 +87,7 @@ class SACPolicy(DDPGPolicy):
act = y * self._action_scale + self._action_bias
log_prob = dist.log_prob(x) - torch.log(
self._action_scale * (1 - y.pow(2)) + self.__eps)
log_prob = torch.unsqueeze(torch.sum(log_prob, 1), 1)
act = act.clamp(self._range[0], self._range[1])
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)