parent
f13e415eb0
commit
75d7c9f1d9
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
from tianshou.data import Batch, ReplayBuffer
|
||||||
from tianshou.exploration import BaseNoise
|
from tianshou.exploration import BaseNoise
|
||||||
from tianshou.policy import DDPGPolicy
|
from tianshou.policy import DDPGPolicy
|
||||||
|
|
||||||
@ -121,15 +121,9 @@ class SACPolicy(DDPGPolicy):
|
|||||||
# apply correction for Tanh squashing when computing logprob from Gaussian
|
# apply correction for Tanh squashing when computing logprob from Gaussian
|
||||||
# See the original SAC paper (arXiv:1801.01290), Eq. 21
|
# See the original SAC paper (arXiv:1801.01290), Eq. 21
|
||||||
# in Appendix C, for an explanation of this equation.
|
# in Appendix C, for an explanation of this equation.
|
||||||
if self.action_scaling and self.action_space is not None:
|
|
||||||
low, high = self.action_space.low, self.action_space.high # type: ignore
|
|
||||||
action_scale = to_torch_as((high - low) / 2.0, act)
|
|
||||||
else:
|
|
||||||
action_scale = 1.0 # type: ignore
|
|
||||||
squashed_action = torch.tanh(act)
|
squashed_action = torch.tanh(act)
|
||||||
log_prob = log_prob - torch.log(
|
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
|
||||||
action_scale * (1 - squashed_action.pow(2)) + self.__eps
|
self.__eps).sum(-1, keepdim=True)
|
||||||
).sum(-1, keepdim=True)
|
|
||||||
return Batch(
|
return Batch(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
act=squashed_action,
|
act=squashed_action,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user