Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass None as `critic2` and `critic2_optim`. In
fact, the resulting default behavior should cover the vast majority of
cases
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None`

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was
non-optional, as there was a ValueError in BasePolicy.init. So now
several examples were fixed to reflect that
2. `reward_normalization` was removed from DDPG and its children. Passing
`True` was never allowed there (an error would have been raised in
`compute_n_step_reward`), so it is now removed from the interface
3. renamed `critic1` and similar to `critic`, in order to have uniform
interfaces. Note that the `critic` in DDPG was optional for the sole
reason that child classes used `critic1`. I removed this optionality
(DDPG can't do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)

## Additional changes: 
1. Removed type and default declaration from docstring. This kind of
duplication is really not necessary
2. Policy constructors are now only called using named arguments, not a
fragile mixture of positional and named as before
3. Minor beautifications in typing and code
4. Generally shortened docstrings and made them uniform across all
policies (hopefully)

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00

148 lines
6.0 KiB
Python

from copy import deepcopy
from typing import Any, Literal, Self
import gymnasium as gym
import numpy as np
import torch
from tianshou.data import ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.optim import clone_optimizer
class TD3Policy(DDPGPolicy):
    """Implementation of TD3, arXiv:1802.09477.

    :param actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param actor_optim: the optimizer for actor network.
    :param critic: the first critic network. (s, a -> Q(s, a))
    :param critic_optim: the optimizer for the first critic network.
    :param action_space: Env's action space. Should be gym.spaces.Box.
    :param critic2: the second critic network. (s, a -> Q(s, a)).
        If None, use the same network as critic (via deepcopy).
    :param critic2_optim: the optimizer for the second critic network.
        If None, clone critic_optim to use for critic2.parameters().
        Must be given whenever critic2 is given.
    :param tau: param for soft update of the target network.
    :param gamma: discount factor, in [0, 1].
    :param exploration_noise: add noise to action for exploration.
        This is useful when solving "hard exploration" problems.
        "default" is equivalent to GaussianNoise(sigma=0.1).
    :param policy_noise: the noise used in updating policy network.
    :param update_actor_freq: the update frequency of actor network.
    :param noise_clip: the clipping range used in updating policy network.
        Clipping is only applied if this value is greater than zero.
    :param estimation_step: forwarded unchanged to the
        :class:`~tianshou.policy.DDPGPolicy` constructor (presumably the
        number of steps used for n-step returns; confirm against the parent).
    :param observation_space: Env's observation space.
    :param action_scaling: if True, scale the action from [-1, 1] to the range
        of action_space. Only used if the action_space is continuous.
    :param action_bound_method: method to bound action to range [-1, 1].
        Only used if the action_space is continuous.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate
        in optimizer in each policy.update()

    :raises ValueError: if critic2 is provided without critic2_optim.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic: torch.nn.Module,
        critic_optim: torch.optim.Optimizer,
        action_space: gym.Space,
        critic2: torch.nn.Module | None = None,
        critic2_optim: torch.optim.Optimizer | None = None,
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: BaseNoise | Literal["default"] | None = "default",
        policy_noise: float = 0.2,
        update_actor_freq: int = 2,
        noise_clip: float = 0.5,
        estimation_step: int = 1,
        observation_space: gym.Space | None = None,
        action_scaling: bool = True,
        action_bound_method: Literal["clip"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        # TODO: reduce duplication with SAC.
        # Some intermediate class, like TwoCriticPolicy?
        super().__init__(
            actor=actor,
            actor_optim=actor_optim,
            critic=critic,
            critic_optim=critic_optim,
            action_space=action_space,
            tau=tau,
            gamma=gamma,
            exploration_noise=exploration_noise,
            estimation_step=estimation_step,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            observation_space=observation_space,
            lr_scheduler=lr_scheduler,
        )
        # Compare against None explicitly instead of relying on truthiness:
        # container modules (e.g. torch.nn.Sequential) define __len__, so an
        # empty one would be falsy and silently treated as "not provided".
        if critic2 is not None and critic2_optim is None:
            raise ValueError("critic2_optim must be provided if critic2 is provided")
        if critic2 is None:
            critic2 = deepcopy(critic)
        if critic2_optim is None:
            critic2_optim = clone_optimizer(critic_optim, critic2.parameters())

        # Second critic plus a frozen-in-eval-mode copy used as its target.
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim

        self.policy_noise = policy_noise
        self.update_actor_freq = update_actor_freq
        self.noise_clip = noise_clip

        # Number of learn() calls so far; drives the delayed actor update.
        self._cnt = 0
        # Most recent actor loss, reported on steps where the actor is skipped.
        self._last = 0

    def train(self, mode: bool = True) -> Self:
        """Set the training mode of the policy, the actor and both critics.

        The ``*_old`` target networks are not touched here.

        :param mode: True for training mode, False for evaluation mode.
        :return: this policy instance.
        """
        self.training = mode
        self.actor.train(mode)
        self.critic.train(mode)
        self.critic2.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update both target critics and the target actor using ``self.tau``.

        Delegates to ``soft_update``, which is defined outside this class.
        """
        self.soft_update(self.critic_old, self.critic, self.tau)
        self.soft_update(self.critic2_old, self.critic2, self.tau)
        self.soft_update(self.actor_old, self.actor, self.tau)

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        """Compute the target Q-value as the minimum of the two target critics.

        The action is produced by the target actor on ``obs_next`` and then
        perturbed with Gaussian noise scaled by ``policy_noise`` (clamped to
        ``[-noise_clip, noise_clip]`` when ``noise_clip > 0``).

        :param buffer: the replay buffer to read the transitions from.
        :param indices: indices into the buffer selecting the transitions.
        :return: element-wise minimum of the two target critics' outputs.
        """
        batch = buffer[indices]  # batch.obs: s_{t+n}
        act_ = self(batch, model="actor_old", input="obs_next").act
        noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise
        if self.noise_clip > 0.0:
            noise = noise.clamp(-self.noise_clip, self.noise_clip)
        act_ += noise
        return torch.min(
            self.critic_old(batch.obs_next, act_),
            self.critic2_old(batch.obs_next, act_),
        )

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        """Run one update of both critics and, every ``update_actor_freq`` calls, the actor.

        When the actor is updated, the target networks are synchronised via
        :meth:`sync_weight` as well. ``batch.weight`` is overwritten with the
        mean of the two critics' TD errors.

        :param batch: the batch of data to learn from.
        :return: a dict with the actor loss (the last computed one if the actor
            was not updated in this call) and both critic losses.
        """
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
        td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor (delayed update: only on every update_actor_freq-th call)
        if self._cnt % self.update_actor_freq == 0:
            actor_loss = -self.critic(batch.obs, self(batch, eps=0.0).act).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self._last = actor_loss.item()
            self.actor_optim.step()
            self.sync_weight()
        self._cnt += 1
        return {
            "loss/actor": self._last,
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }