parent
f13e415eb0
commit
75d7c9f1d9
@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Independent, Normal
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.policy import DDPGPolicy
|
||||
|
||||
@ -121,15 +121,9 @@ class SACPolicy(DDPGPolicy):
|
||||
# apply correction for Tanh squashing when computing logprob from Gaussian
|
||||
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
|
||||
# in appendix C to get some understanding of this equation.
|
||||
if self.action_scaling and self.action_space is not None:
|
||||
low, high = self.action_space.low, self.action_space.high # type: ignore
|
||||
action_scale = to_torch_as((high - low) / 2.0, act)
|
||||
else:
|
||||
action_scale = 1.0 # type: ignore
|
||||
squashed_action = torch.tanh(act)
|
||||
log_prob = log_prob - torch.log(
|
||||
action_scale * (1 - squashed_action.pow(2)) + self.__eps
|
||||
).sum(-1, keepdim=True)
|
||||
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
|
||||
self.__eps).sum(-1, keepdim=True)
|
||||
return Batch(
|
||||
logits=logits,
|
||||
act=squashed_action,
|
||||
|
Loading…
x
Reference in New Issue
Block a user