Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. In
fact, the resulting default behavior should cover the vast majority of
cases.
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None` (a sketch of the idea is shown below).
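
For context, here is a minimal sketch of what such a helper could look like. This is an assumption about the approach, not necessarily the exact code added by this PR:

```python
import torch


def clone_optimizer(
    optim: torch.optim.Optimizer,
    new_params: list[torch.nn.Parameter],
) -> torch.optim.Optimizer:
    """Create an optimizer of the same class and defaults for a new set of parameters."""
    # type(optim) recovers the concrete class (e.g. Adam); optim.defaults holds the
    # hyperparameters it was constructed with (lr, betas, weight_decay, ...).
    return type(optim)(new_params, **optim.defaults)
```

With a helper like this, `critic2_optim=None` can be resolved internally by cloning `critic_optim` for the second critic's parameters.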

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already effectively
non-optional, since `BasePolicy.__init__` raised a ValueError when it was
missing. Several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. Passing it
as `True` was never allowed there; an error would have been raised in
`compute_n_step_reward`. It has now been removed from the interface.
3. Renamed `critic1` and similar to `critic` in order to have uniform
interfaces. Note that the `critic` in DDPG was optional for the sole
reason that child classes used `critic1`. I removed this optionality
(DDPG can't do anything with `critic=None`); a construction example is
shown below.
4. Several renamings of fields (mostly private to public, so backwards
compatible).
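
To make the new interface concrete, here is a hedged construction sketch. `ToyActor` and `ToyCritic` are placeholder modules invented for this example; real code would use the network classes from `tianshou.utils.net`:

```python
import gymnasium as gym
import torch

from tianshou.policy import DDPGPolicy


class ToyActor(torch.nn.Module):
    """Stand-in actor: obs -> action in [-1, 1], plus hidden state."""

    max_action = 1.0  # DDPGPolicy warns if action_scaling is on and max_action != 1

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.net = torch.nn.Linear(obs_dim, act_dim)

    def forward(self, obs, state=None, info=None):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return torch.tanh(self.net(obs)), state


class ToyCritic(torch.nn.Module):
    """Stand-in critic: (obs, act) -> Q(obs, act)."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.net = torch.nn.Linear(obs_dim + act_dim, 1)

    def forward(self, obs, act):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        act = torch.as_tensor(act, dtype=torch.float32)
        return self.net(torch.cat([obs, act], dim=-1))


env = gym.make("Pendulum-v1")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

actor = ToyActor(obs_dim, act_dim)
critic = ToyCritic(obs_dim, act_dim)
policy = DDPGPolicy(
    actor=actor,
    actor_optim=torch.optim.Adam(actor.parameters(), lr=1e-3),
    critic=critic,  # formerly `critic1` in the two-critic policies
    critic_optim=torch.optim.Adam(critic.parameters(), lr=1e-3),
    action_space=env.action_space,  # now always required
    tau=0.005,
    gamma=0.99,
)
```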

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of
duplication is really not necessary.
2. Policy constructors are now only called using named arguments, not a
fragile mixture of positional and named arguments as before (see the
snippet below).
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all
policies (hopefully).
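
As a small illustration (reusing the names from the sketch above), the bare `*` in the constructor signatures now rejects positional calls:

```python
# Old call style -- now fails with a TypeError, since every argument is keyword-only:
# policy = DDPGPolicy(actor, actor_optim, critic, critic_optim, env.action_space)

# New call style:
policy = DDPGPolicy(
    actor=actor,
    actor_optim=torch.optim.Adam(actor.parameters(), lr=1e-3),
    critic=critic,
    critic_optim=torch.optim.Adam(critic.parameters(), lr=1e-3),
    action_space=env.action_space,
)
```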

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00

205 lines
8.0 KiB
Python

import warnings
from copy import deepcopy
from typing import Any, Literal, Self

import gymnasium as gym
import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler


class DDPGPolicy(BasePolicy):
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
:param actor: The actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> model_output)
:param actor_optim: The optimizer for actor network.
:param critic: The critic network. (s, a -> Q(s, a))
:param critic_optim: The optimizer for critic network.
:param action_space: Env's action space.
:param tau: Param for soft update of the target network.
:param gamma: Discount factor, in [0, 1].
:param exploration_noise: The exploration noise, added to the action. Defaults
to ``GaussianNoise(sigma=0.1)``.
:param estimation_step: The number of steps to look ahead.
:param observation_space: Env's observation space.
:param action_scaling: if True, scale the action from [-1, 1] to the range
of action_space. Only used if the action_space is continuous.
:param action_bound_method: method to bound action to range [-1, 1].
Only used if the action_space is continuous.
:param lr_scheduler: if not None, will be called in `policy.update()`.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
        self,
        *,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic: torch.nn.Module,
        critic_optim: torch.optim.Optimizer,
        action_space: gym.Space,
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: BaseNoise | Literal["default"] | None = "default",
        estimation_step: int = 1,
        observation_space: gym.Space | None = None,
        action_scaling: bool = True,
        # tanh not supported, see assert below
        action_bound_method: Literal["clip"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}"
        assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}"
        assert action_bound_method != "tanh", (  # type: ignore[comparison-overlap]
            "tanh mapping is not supported "
            "in policies where the action is used as input to the critic, because "
            "a raw action in range (-inf, inf) will cause instability in training"
        )
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            lr_scheduler=lr_scheduler,
        )
        if action_scaling and not np.isclose(actor.max_action, 1.0):  # type: ignore
            warnings.warn(
                "action_scaling and action_bound_method are only intended to deal "
                "with unbounded model action space, but found actor model bound "
                f"action space with max_action={actor.max_action}. "
                "Consider using the unbounded=True option of the actor model, "
                "or set action_scaling to False and action_bound_method to None.",
            )
        self.actor = actor
        self.actor_old = deepcopy(actor)
        self.actor_old.eval()
        self.actor_optim = actor_optim
        self.critic = critic
        self.critic_old = deepcopy(critic)
        self.critic_old.eval()
        self.critic_optim = critic_optim
        self.tau = tau
        self.gamma = gamma
        if exploration_noise == "default":
            exploration_noise = GaussianNoise(sigma=0.1)
        # TODO: IMPORTANT - can't call this "exploration_noise" because confusingly,
        #  there is already a method called exploration_noise() in the base class.
        #  Now this method doesn't apply any noise and is also not overridden. See TODO there.
        self._exploration_noise = exploration_noise
        # it makes little difference whether GaussianNoise or OUNoise is used
        # self.noise = OUNoise()
        self.estimation_step = estimation_step

    def set_exp_noise(self, noise: BaseNoise | None) -> None:
"""Set the exploration noise."""
self._exploration_noise = noise
def train(self, mode: bool = True) -> Self:
"""Set the module in training mode, except for the target network."""
self.training = mode
self.actor.train(mode)
self.critic.train(mode)
return self
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
self.soft_update(self.actor_old, self.actor, self.tau)
self.soft_update(self.critic_old, self.critic, self.tau)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
return self.critic_old(batch.obs_next, self(batch, model="actor_old", input="obs_next").act)
def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> RolloutBatchProtocol | BatchWithReturnsProtocol:
        return self.compute_nstep_return(
            batch=batch,
            buffer=buffer,
            indices=indices,
            target_q_fn=self._target_q,
            gamma=self.gamma,
            n_step=self.estimation_step,
        )

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        model: Literal["actor", "actor_old"] = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> BatchProtocol:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        actions, hidden = model(obs, state=state, info=batch.info)
        return Batch(act=actions, state=hidden)

    @staticmethod
    def _mse_optimizer(
        batch: RolloutBatchProtocol,
        critic: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """A simple wrapper script for updating critic network."""
        weight = getattr(batch, "weight", 1.0)
        current_q = critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        # critic_loss = F.mse_loss(current_q1, target_q)
        critic_loss = (td.pow(2) * weight).mean()
        optimizer.zero_grad()
        critic_loss.backward()
        optimizer.step()
        return td, critic_loss

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        # critic
        td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
        batch.weight = td  # prio-buffer
        # actor
        actor_loss = -self.critic(batch.obs, self(batch).act).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        self.sync_weight()
        return {"loss/actor": actor_loss.item(), "loss/critic": critic_loss.item()}

    def exploration_noise(
        self,
        act: np.ndarray | BatchProtocol,
        batch: RolloutBatchProtocol,
    ) -> np.ndarray | BatchProtocol:
        if self._exploration_noise is None:
            return act
        if isinstance(act, np.ndarray):
            return act + self._exploration_noise(act.shape)
        warnings.warn("Cannot add exploration noise to a non-numpy-array action.")
        return act