This commit is contained in:
Trinkle23897 2020-06-13 17:06:08 +08:00
parent 3774258cc7
commit 5f2f05a570

View File

@ -101,7 +101,8 @@ class SACPolicy(DDPGPolicy):
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
log_prob = dist.log_prob(x) - torch.log(
self._action_scale * (1 - y.pow(2)) + self.__eps)
self._action_scale * (1 - y.pow(2)) + self.__eps
).sum(-1, keepdim=True)
act = act.clamp(self._range[0], self._range[1])
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)