Closes #947. This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs. ## Functional changes: 1. Added the possibility to pass None as `critic2` and `critic2_optim`. In fact, the default behavior should then cover the vast majority of cases. 2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`. ## Breaking changes: 1. `action_space` is no longer optional. In fact, it already was non-optional, as there was a ValueError in `BasePolicy.__init__`. So now several examples were fixed to reflect that. 2. `reward_normalization` was removed from DDPG and its children. It was never allowed to pass it as `True` there; an error would have been raised in `compute_n_step_reward`. Now I removed it from the interface. 3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`). 4. Several renamings of fields (mostly private to public, so backwards compatible). ## Additional changes: 1. Removed type and default declarations from docstrings. This kind of duplication is really not necessary. 2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named as before. 3. Minor beautifications in typing and code. 4. Generally shortened docstrings and made them uniform across all policies (hopefully). ## Comment: With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work. --------- Co-authored-by: Dominik Jain <d.jain@appliedai.de>
183 lines
6.7 KiB
Python
183 lines
6.7 KiB
Python
from typing import Any, cast
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import torch
|
|
from overrides import override
|
|
from torch.distributions import Categorical
|
|
|
|
from tianshou.data import Batch, ReplayBuffer, to_torch
|
|
from tianshou.data.batch import BatchProtocol
|
|
from tianshou.data.types import RolloutBatchProtocol
|
|
from tianshou.policy import SACPolicy
|
|
from tianshou.policy.base import TLearningRateScheduler
|
|
|
|
|
|
class DiscreteSACPolicy(SACPolicy):
    """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.

    :param actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param actor_optim: the optimizer for actor network.
    :param critic: the first critic network. It maps an observation to the
        Q-values of all discrete actions at once (s -> Q(s, *)); the Q-value of
        the taken action is selected afterwards by indexing.
    :param critic_optim: the optimizer for the first critic network.
    :param action_space: Env's action space. Should be gym.spaces.Discrete.
    :param critic2: the second critic network (s -> Q(s, *)).
        If None, use the same network as critic (via deepcopy).
    :param critic2_optim: the optimizer for the second critic network.
        If None, clone critic_optim to use for critic2.parameters().
    :param tau: param for soft update of the target network.
    :param gamma: discount factor, in [0, 1].
    :param alpha: entropy regularization coefficient.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided,
        then alpha is automatically tuned.
    :param estimation_step: the number of steps to look ahead for calculating
        the (n-step) target returns.
    :param observation_space: Env's observation space.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate
        in optimizer in each policy.update()

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic: torch.nn.Module,
        critic_optim: torch.optim.Optimizer,
        action_space: gym.spaces.Discrete,
        critic2: torch.nn.Module | None = None,
        critic2_optim: torch.optim.Optimizer | None = None,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2,
        estimation_step: int = 1,
        observation_space: gym.Space | None = None,
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        super().__init__(
            actor=actor,
            actor_optim=actor_optim,
            critic=critic,
            critic_optim=critic_optim,
            action_space=action_space,
            critic2=critic2,
            critic2_optim=critic2_optim,
            tau=tau,
            gamma=gamma,
            alpha=alpha,
            estimation_step=estimation_step,
            # Note: inheriting from continuous sac reduces code duplication,
            # but continuous stuff has to be disabled
            exploration_noise=None,
            action_scaling=False,
            action_bound_method=None,
            observation_space=observation_space,
            lr_scheduler=lr_scheduler,
        )

    # TODO: violates Liskov substitution principle, incompatible action space with SAC
    # Not too urgent, but still..
    @override
    def _check_field_validity(self) -> None:
        """Ensure the action space is discrete.

        :raises ValueError: if ``self.action_space`` is not a ``gym.spaces.Discrete``.
        """
        if not isinstance(self.action_space, gym.spaces.Discrete):
            raise ValueError(
                f"DiscreteSACPolicy only supports gym.spaces.Discrete, but got {self.action_space=}. "
                "Please use SACPolicy for continuous action spaces.",
            )

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: dict | Batch | np.ndarray | None = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute the categorical action distribution and an action for ``batch``.

        :param batch: the input batch; ``batch[input]`` is passed to the actor
            together with ``batch.info``.
        :param state: the hidden state forwarded to the actor.
        :param input: the key of ``batch`` to read observations from, e.g.
            ``"obs"`` or ``"obs_next"``.
        :return: a Batch with ``logits`` (actor output), ``act``, ``state`` (the
            hidden output of the actor) and ``dist`` (a ``Categorical`` built from
            the logits). ``act`` is the argmax of the logits when
            ``deterministic_eval`` is set and the policy is not in training mode;
            otherwise it is sampled from ``dist``.
        """
        obs = batch[input]
        logits, hidden = self.actor(obs, state=state, info=batch.info)
        dist = Categorical(logits=logits)
        if self.deterministic_eval and not self.training:
            act = logits.argmax(dim=-1)
        else:
            act = dist.sample()
        return Batch(logits=logits, act=act, state=hidden, dist=dist)

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        """Compute the soft value of the next observations used for the TD target.

        Evaluates ``sum_a pi(a|s') * min(Q1_old(s', a), Q2_old(s', a))
        + alpha * H(pi(.|s'))`` with the target critics ``critic_old`` and
        ``critic2_old``.

        :param buffer: the replay buffer to read transitions from.
        :param indices: indices of the transitions in ``buffer``.
        :return: the soft value, reduced over the last (action) dimension.
        """
        batch = buffer[indices]  # batch.obs: s_{t+n}
        obs_next_result = self(batch, input="obs_next")
        dist = obs_next_result.dist
        # With discrete actions the expectation over the next action is taken
        # exactly (weighting by the action probabilities) instead of sampling.
        target_q = dist.probs * torch.min(
            self.critic_old(batch.obs_next),
            self.critic2_old(batch.obs_next),
        )
        return target_q.sum(dim=-1) + self.alpha * dist.entropy()

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        """Perform one update of both critics, the actor and (optionally) alpha.

        ``batch.returns`` is used as the TD target. As side effects, ``weight``
        is popped from ``batch`` (default 1.0) and ``batch.weight`` is then set
        to the average of both critics' TD errors (for a prioritized buffer).
        When alpha is auto-tuned, ``self.alpha`` is updated. Finally,
        ``sync_weight`` is called to update the target networks.

        :return: a dict with ``loss/actor``, ``loss/critic1``, ``loss/critic2``
            and, when alpha is auto-tuned, ``loss/alpha`` and ``alpha``.
        """
        weight = batch.pop("weight", 1.0)
        target_q = batch.returns.flatten()
        act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)

        # critic 1: the critic outputs Q-values for every action; pick the taken one
        current_q1 = self.critic(batch.obs).gather(1, act).flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()

        self.critic_optim.zero_grad()
        critic1_loss.backward()
        self.critic_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor: critics are evaluated without gradients so only the actor is updated
        dist = self(batch).dist
        entropy = dist.entropy()
        with torch.no_grad():
            current_q1a = self.critic(batch.obs)
            current_q2a = self.critic2(batch.obs)
            q = torch.min(current_q1a, current_q2a)
        actor_loss = -(self.alpha * entropy + (dist.probs * q).sum(dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self.is_auto_alpha:
            # -entropy plays the role of log_prob in the continuous SAC alpha loss
            log_prob = -entropy.detach() + self.target_entropy
            alpha_loss = -(self.log_alpha * log_prob).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self.is_auto_alpha:
            self.alpha = cast(torch.Tensor, self.alpha)
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self.alpha.item()

        return result

    def exploration_noise(
        self,
        act: np.ndarray | BatchProtocol,
        batch: RolloutBatchProtocol,
    ) -> np.ndarray | BatchProtocol:
        """Return ``act`` unchanged.

        No extra noise is added; in training mode ``forward`` already samples
        actions from the categorical policy.
        """
        return act
|