Closes #947

This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. The resulting default behavior should cover the vast majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None` (see the sketch below this description).

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already effectively non-optional, since `BasePolicy.__init__` raised a `ValueError` when it was missing; several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. Passing `True` there was never allowed and would have raised an error in `compute_n_step_reward`, so it has been removed from the interface.
3. Renamed `critic1` and similar to `critic` in order to have uniform interfaces. Note that `critic` in DDPG was optional for the sole reason that child classes used `critic1`; this optionality was removed (DDPG can't do anything with `critic=None`).
4. Several renamings of fields (mostly private to public, so backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings; this kind of duplication is really not necessary.
2. Policy constructors are now called using named arguments only, not the fragile mixture of positional and named arguments used before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
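As referenced in the functional changes above, here is a minimal sketch of what a `clone_optimizer` helper could look like. This is an illustration, not the implementation from this PR; the second parameter's name and the reliance on `Optimizer.defaults` are assumptions.

```python
from collections.abc import Iterable

import torch


def clone_optimizer(
    optim: torch.optim.Optimizer,
    new_params: Iterable[torch.nn.Parameter],
) -> torch.optim.Optimizer:
    """Sketch: build a new optimizer of the same class and with the same
    default hyperparameters as ``optim``, but bound to ``new_params``."""
    return type(optim)(new_params, **optim.defaults)
```

With a helper along these lines, a policy can accept `critic2_optim=None` and internally create an optimizer for `critic2` that mirrors the one passed for the first critic.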
140 lines · 5.6 KiB · Python
from copy import deepcopy
from typing import Any, Literal

import gymnasium as gym
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

from tianshou.data import to_torch, to_torch_as
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import PGPolicy


class DiscreteCRRPolicy(PGPolicy):
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
|
|
|
|
:param actor: the actor network following the rules in
|
|
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
|
:param critic: the action-value critic (i.e., Q function)
|
|
network. (s -> Q(s, \*))
|
|
:param optim: a torch.optim for optimizing the model.
|
|
:param discount_factor: in [0, 1].
|
|
:param str policy_improvement_mode: type of the weight function f. Possible
|
|
values: "binary"/"exp"/"all".
|
|
:param ratio_upper_bound: when policy_improvement_mode is "exp", the value
|
|
of the exp function is upper-bounded by this parameter.
|
|
:param beta: when policy_improvement_mode is "exp", this is the denominator
|
|
of the exp function.
|
|
:param min_q_weight: weight for CQL loss/regularizer. Default to 10.
|
|
:param target_update_freq: the target network update frequency (0 if
|
|
you do not use the target network).
|
|
:param reward_normalization: if True, will normalize the *returns*
|
|
by subtracting the running mean and dividing by the running standard deviation.
|
|
Can be detrimental to performance! See TODO in process_fn.
|
|
:param observation_space: Env's observation space.
|
|
:param lr_scheduler: if not None, will be called in `policy.update()`.
|
|
|
|
.. seealso::
|
|
Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed
|
|
explanation.
|
|
"""
|
|
|
|
def __init__(
|
|
        self,
        *,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        action_space: gym.spaces.Discrete,
        discount_factor: float = 0.99,
        policy_improvement_mode: Literal["exp", "binary", "all"] = "exp",
        ratio_upper_bound: float = 20.0,
        beta: float = 1.0,
        min_q_weight: float = 10.0,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        observation_space: gym.Space | None = None,
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        super().__init__(
            actor=actor,
            optim=optim,
            action_space=action_space,
            dist_fn=lambda x: Categorical(logits=x),
            discount_factor=discount_factor,
            reward_normalization=reward_normalization,
            observation_space=observation_space,
            action_scaling=False,
            action_bound_method=None,
            lr_scheduler=lr_scheduler,
        )
        self.critic = critic
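        # Keep separate target networks only when periodic syncing is enabled
        # (target_update_freq > 0); otherwise actor_old/critic_old simply alias
        # the live networks.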
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        self._iter = 0
        if self._target:
            self.actor_old = deepcopy(self.actor)
            self.actor_old.eval()
            self.critic_old = deepcopy(self.critic)
            self.critic_old.eval()
        else:
            self.actor_old = self.actor
            self.critic_old = self.critic
        self._policy_improvement_mode = policy_improvement_mode
        self._ratio_upper_bound = ratio_upper_bound
        self._beta = beta
        self._min_q_weight = min_q_weight

    def sync_weight(self) -> None:
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.critic_old.load_state_dict(self.critic.state_dict())

    def learn(  # type: ignore
        self,
        batch: RolloutBatchProtocol,
        *args: Any,
        **kwargs: Any,
    ) -> dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        q_t = self.critic(batch.obs)
        act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
        qa_t = q_t.gather(1, act.unsqueeze(1))
        # Critic loss
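        # The target is an expected-SARSA-style backup: Q-values from the target
        # critic at obs_next are averaged under the target actor's action
        # distribution and zeroed for terminal transitions.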
        with torch.no_grad():
            target_a_t, _ = self.actor_old(batch.obs_next)
            target_m = Categorical(logits=target_a_t)
            q_t_target = self.critic_old(batch.obs_next)
            rew = to_torch_as(batch.rew, q_t_target)
            expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
            expected_target_q[batch.done > 0] = 0.0
            target = rew.unsqueeze(1) + self.gamma * expected_target_q
        critic_loss = 0.5 * F.mse_loss(qa_t, target)
        # Actor loss
        act_target, _ = self.actor(batch.obs)
        dist = Categorical(logits=act_target)
        expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True)
        advantage = qa_t - expected_policy_q
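        # CRR weight function f: "binary" keeps only actions with positive
        # advantage, "exp" uses a clipped exp(advantage / beta), and "all"
        # weights every action equally (plain behavior cloning).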
        if self._policy_improvement_mode == "binary":
            actor_loss_coef = (advantage > 0).float()
        elif self._policy_improvement_mode == "exp":
            actor_loss_coef = (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound)
        else:
            actor_loss_coef = 1.0  # effectively behavior cloning
        actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean()
        # CQL loss/regularizer
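        # Penalizes Q-values of actions outside the dataset by pushing the
        # logsumexp over all actions down towards the Q-value of the taken action.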
        min_q_loss = (q_t.logsumexp(1) - qa_t).mean()
        loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {
            "loss": loss.item(),
            "loss/actor": actor_loss.item(),
            "loss/critic": critic_loss.item(),
            "loss/cql": min_q_loss.item(),
        }