import torch
import numpy as np
from copy import deepcopy
from torch.distributions import Independent, Normal
from typing import Any, Dict, Tuple, Union, Optional

from tianshou.policy import DDPGPolicy
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, to_torch_as


class SACPolicy(DDPGPolicy):
    """Implementation of Soft Actor-Critic. arXiv:1812.05905.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for the actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param float tau: param for the soft update of the target network.
        Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
        regularization coefficient. Default to 0.2.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param BaseNoise exploration_noise: noise added to the action for exploration.
        Default to None. This is useful when solving hard-exploration problems.
    :param bool deterministic_eval: whether to use the deterministic action (mean
        of the Gaussian policy) instead of the stochastic action sampled by the
        policy. Default to True.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound the action to range [-1, 1],
        can be either "clip" (for simply clipping the action) or empty string for
        no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you
        want to use option "action_scaling" or "action_bound_method".
        Default to None.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """
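    # A minimal construction sketch (illustrative only, not part of this file):
    # actor, critic1, critic2 and env are assumed to be built elsewhere, and the
    # learning rates are placeholders. With auto-tuned alpha, pass a
    # (target_entropy, log_alpha, alpha_optim) tuple instead of a float:
    #
    #     actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
    #     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=1e-3)
    #     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=1e-3)
    #     # common heuristic: target entropy = -dim(action space)
    #     target_entropy = -float(np.prod(env.action_space.shape))
    #     log_alpha = torch.zeros(1, requires_grad=True)
    #     alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
    #     policy = SACPolicy(
    #         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
    #         tau=0.005, gamma=0.99,
    #         alpha=(target_entropy, log_alpha, alpha_optim),
    #         action_space=env.action_space,
    #     )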

    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        reward_normalization: bool = False,
        estimation_step: int = 1,
        exploration_noise: Optional[BaseNoise] = None,
        deterministic_eval: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            None, None, None, None, tau, gamma, exploration_noise,
            reward_normalization, estimation_step, **kwargs)
        self.actor, self.actor_optim = actor, actor_optim
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim

        self._is_auto_alpha = False
        self._alpha: Union[float, torch.Tensor]
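        # alpha is either a fixed float or a (target_entropy, log_alpha, alpha_optim)
        # tuple; in the tuple case log_alpha is optimized in learn() so that the
        # policy entropy is driven toward target_entropy.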
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        self._deterministic_eval = deterministic_eval
        # small constant that keeps the log in the tanh correction finite
        self.__eps = np.finfo(np.float32).eps.item()

    def train(self, mode: bool = True) -> "SACPolicy":
        self.training = mode
        self.actor.train(mode)
        self.critic1.train(mode)
        self.critic2.train(mode)
        return self

    def sync_weight(self) -> None:
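        # Polyak (soft) update of the target critics:
        #   theta_target = (1 - tau) * theta_target + tau * theta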
        for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        logits, h = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            act = logits[0]
        else:
            act = dist.rsample()
        log_prob = dist.log_prob(act).unsqueeze(-1)
        # apply correction for tanh squashing when computing the log-prob from the
        # Gaussian; see the original SAC paper (arXiv 1801.01290), Eq. 21 in
        # Appendix C, for the derivation of this equation.
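        # Concretely, with u ~ Normal(mu, sigma) and the action squashed as
        # tanh(u) (then rescaled by action_scale outside the distribution), the
        # change of variables gives
        #   log pi(a|s) = log N(u; mu, sigma)
        #                 - sum_i log(action_scale_i * (1 - tanh(u_i)^2) + eps),
        # which is the subtraction performed below; the small eps keeps the
        # logarithm finite when tanh saturates.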
        if self.action_scaling and self.action_space is not None:
            action_scale = to_torch_as(
                (self.action_space.high - self.action_space.low) / 2.0, act)
        else:
            action_scale = 1.0  # type: ignore
        squashed_action = torch.tanh(act)
        log_prob = log_prob - torch.log(
            action_scale * (1 - squashed_action.pow(2)) + self.__eps
        ).sum(-1, keepdim=True)
        return Batch(logits=logits, act=squashed_action,
                     state=h, dist=dist, log_prob=log_prob)

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs: s_{t+n}
        obs_next_result = self(batch, input='obs_next')
        a_ = obs_next_result.act
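        # Soft Bellman target with clipped double-Q and an entropy bonus:
        #   target_q = min(Q1'(s', a'), Q2'(s', a')) - alpha * log pi(a'|s'),
        # where a' is sampled from the current policy at s' and Q1'/Q2' are the
        # target critics.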
        target_q = torch.min(
            self.critic1_old(batch.obs_next, a_),
            self.critic2_old(batch.obs_next, a_),
        ) - self._alpha * obs_next_result.log_prob
        return target_q

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(
            batch, self.critic1, self.critic1_optim)
        td2, critic2_loss = self._mse_optimizer(
            batch, self.critic2, self.critic2_optim)
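        # the averaged TD error of the two critics doubles as the new sample
        # priority when a prioritized replay buffer is in use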
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
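        # The actor minimizes E[alpha * log pi(a|s) - min(Q1(s, a), Q2(s, a))]
        # with a re-sampled from the current policy via rsample(), so gradients
        # flow through the critics into the actor parameters.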
        obs_result = self(batch)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten()
                      - torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
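            # Automatic entropy tuning:
            #   alpha_loss = -E[log_alpha * (log pi(a|s).detach() + target_entropy)];
            # gradient descent on log_alpha raises alpha when the policy entropy
            # drops below target_entropy and lowers it otherwise.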
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result