from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch.distributions import Categorical

from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import SACPolicy


class DiscreteSACPolicy(SACPolicy):
    """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for the actor network.
    :param torch.nn.Module critic1: the first critic network. (s -> Q(s))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s -> Q(s))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param float tau: coefficient for the soft update of the target networks.
        Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
        regularization coefficient. Default to 0.2.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param int estimation_step: the number of steps to look ahead for the n-step
        return. Default to 1.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        reward_normalization: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
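        # Actions are discrete indices, so the continuous-action post-processing
        # of SACPolicy is disabled below via action_scaling=False and
        # action_bound_method="".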
        super().__init__(
            actor,
            actor_optim,
            critic1,
            critic1_optim,
            critic2,
            critic2_optim,
            tau,
            gamma,
            alpha,
            reward_normalization,
            estimation_step,
            action_scaling=False,
            action_bound_method="",
            **kwargs
        )
        self._alpha: Union[float, torch.Tensor]

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
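        # The actor returns per-action logits; the action is sampled from the
        # resulting categorical distribution, so the policy is stochastic by design.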
        obs = batch[input]
        logits, h = self.actor(obs, state=state, info=batch.info)
        dist = Categorical(logits=logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs: s_{t+n}
        obs_next_result = self(batch, input="obs_next")
        dist = obs_next_result.dist
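        # Discrete SAC computes the soft value of s_{t+n} exactly, as an
        # expectation over the action set rather than a sampled estimate:
        #   V(s') = sum_a pi(a|s') * min(Q1_targ(s', a), Q2_targ(s', a)) + alpha * H(pi(.|s'))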
        target_q = dist.probs * torch.min(
            self.critic1_old(batch.obs_next),
            self.critic2_old(batch.obs_next),
        )
        target_q = target_q.sum(dim=-1) + self._alpha * dist.entropy()
        return target_q

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        target_q = batch.returns.flatten()
        act = to_torch(
            batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long
        )
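
        # Each critic outputs Q-values for every action; gather(1, act) selects
        # the Q-value of the action actually taken, which is regressed towards
        # the (n-step) return target stored in batch.returns.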
        # critic 1
        current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()

        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer: averaged TD error for priority update
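
        # The actor minimizes E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min Q(s, a)) ],
        # written below as -(alpha * H(pi) + E_{a~pi}[min Q]); the critics are
        # frozen with torch.no_grad() so only the actor receives gradients.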
        # actor
        dist = self(batch).dist
        entropy = dist.entropy()
        with torch.no_grad():
            current_q1a = self.critic1(batch.obs)
            current_q2a = self.critic2(batch.obs)
            q = torch.min(current_q1a, current_q2a)
        actor_loss = -(self._alpha * entropy + (dist.probs * q).sum(dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
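
        # With automatic entropy tuning, log_alpha is adjusted to keep the policy
        # entropy close to the target entropy; -entropy stands in for the expected
        # log-probability E_{a~pi}[log pi(a|s)].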
        if self._is_auto_alpha:
            log_prob = -entropy.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        # soft-update the target critic networks
        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result

    def exploration_noise(self, act: Union[np.ndarray, Batch],
                          batch: Batch) -> Union[np.ndarray, Batch]:
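        # The categorical policy is already stochastic, so no extra exploration
        # noise is added to discrete actions.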
        return act