Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. The
resulting default behavior should cover the vast majority of cases
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None` (see the sketch below)
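
For illustration only, a minimal sketch of what such a helper could look like,
built from standard `torch.optim.Optimizer` attributes; the actual
implementation added in this PR may differ:

```python
from collections.abc import Iterable

import torch


def clone_optimizer(
    optim: torch.optim.Optimizer,
    new_params: Iterable[torch.nn.Parameter],
) -> torch.optim.Optimizer:
    """Create a fresh optimizer of the same class and hyperparameters for new_params."""
    # optim.defaults holds the hyperparameters passed at construction (lr, betas, ...),
    # so the clone starts from the same settings but tracks the new parameters.
    return type(optim)(new_params, **optim.defaults)


# Hypothetical use inside a policy: derive critic2's optimizer when critic2_optim is None.
# critic2_optim = clone_optimizer(critic_optim, critic2.parameters())
```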

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already effectively
non-optional, since `BasePolicy.__init__` raised a `ValueError` without it.
Several examples were fixed to reflect that
2. Removed `reward_normalization` from DDPG and its children. Passing it as
`True` there was never allowed and would have raised an error in
`compute_n_step_reward`, so it has now been removed from the interface
3. Renamed `critic1` and similar to `critic` in order to have uniform
interfaces (see the sketch after this list). Note that `critic` in DDPG was
optional for the sole reason that child classes used `critic1`; this
optionality was removed (DDPG can't do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)
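
A hypothetical before/after sketch of constructing a two-critic policy under
the new interface, assuming `actor`, `critic`, their optimizers, and `env` are
defined elsewhere; the concrete class and remaining parameters are assumptions,
only the keyword-only calling style, the `critic` rename, the required
`action_space`, and the optional `critic2`/`critic2_optim` are claims of this PR:

```python
# Before (roughly): fragile positional/named mix, critic1/critic2, action_space nominally optional
# policy = TD3Policy(actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim)

# After: keyword-only arguments, uniform `critic` naming, `action_space` required;
# critic2/critic2_optim may be left as None and are then derived from critic/critic_optim.
policy = TD3Policy(
    actor=actor,
    actor_optim=actor_optim,
    critic=critic,
    critic_optim=critic_optim,
    action_space=env.action_space,
)
```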

## Additional changes:

1. Removed type and default-value declarations from docstrings; this kind of
duplication is really not necessary
2. Policy constructors are now called using named arguments only, not the
fragile mixture of positional and named arguments used before
3. Minor beautifications in typing and code
4. Generally shortened docstrings and made them uniform across all
policies (hopefully)

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>

from typing import Any, Literal, Self

import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


class ICMPolicy(BasePolicy):
"""Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.
:param policy: a base policy to add ICM to.
:param model: the ICM model.
:param optim: a torch.optim for optimizing the model.
:param lr_scale: the scaling factor for ICM learning.
:param forward_loss_weight: the weight for forward model loss.
:param observation_space: Env's observation space.
:param action_scaling: if True, scale the action from [-1, 1] to the range
of action_space. Only used if the action_space is continuous.
:param action_bound_method: method to bound action to range [-1, 1].
Only used if the action_space is continuous.
:param lr_scheduler: if not None, will be called in `policy.update()`.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
self,
*,
policy: BasePolicy,
model: IntrinsicCuriosityModule,
optim: torch.optim.Optimizer,
lr_scale: float,
reward_scale: float,
forward_loss_weight: float,
action_space: gym.Space,
observation_space: gym.Space | None = None,
action_scaling: bool = False,
action_bound_method: Literal["clip", "tanh"] | None = "clip",
lr_scheduler: TLearningRateScheduler | None = None,
) -> None:
super().__init__(
action_space=action_space,
observation_space=observation_space,
action_scaling=action_scaling,
action_bound_method=action_bound_method,
lr_scheduler=lr_scheduler,
)
self.policy = policy
self.model = model
self.optim = optim
self.lr_scale = lr_scale
self.reward_scale = reward_scale
self.forward_loss_weight = forward_loss_weight

    def train(self, mode: bool = True) -> Self:
"""Set the module in training mode."""
self.policy.train(mode)
self.training = mode
self.model.train(mode)
return self

    def forward(
self,
batch: RolloutBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> BatchProtocol:
"""Compute action over the given batch data by inner policy.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
return self.policy.forward(batch, state, **kwargs)

    def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
return self.policy.exploration_noise(act, batch)

    def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""
if hasattr(self.policy, "set_eps"):
self.policy.set_eps(eps) # type: ignore
else:
raise NotImplementedError

    def process_fn(
self,
batch: RolloutBatchProtocol,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> RolloutBatchProtocol:
"""Pre-process the data from the provided replay buffer.
Used in :meth:`update`. Check out :ref:`process_fn` for more information.
"""
mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next)
batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss)
batch.rew += to_numpy(mse_loss * self.reward_scale)
return self.policy.process_fn(batch, buffer, indices)

    def post_process_fn(
self,
batch: BatchProtocol,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> None:
"""Post-process the data from the provided replay buffer.
Typical usage is to update the sampling weight in prioritized
experience replay. Used in :meth:`update`.
"""
self.policy.post_process_fn(batch, buffer, indices)
batch.rew = batch.policy.orig_rew # restore original reward

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
res = self.policy.learn(batch, **kwargs)
self.optim.zero_grad()
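        # ICM losses: the inverse loss is the cross-entropy between the predicted action
        # (act_hat, from the inverse model) and the action actually taken; the forward
        # loss reuses the feature prediction error computed in process_fn. Both are mixed
        # via forward_loss_weight and scaled by lr_scale before updating the ICM model.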
act_hat = batch.policy.act_hat
act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
inverse_loss = F.cross_entropy(act_hat, act).mean()
forward_loss = batch.policy.mse_loss.mean()
loss = (
(1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss
) * self.lr_scale
loss.backward()
self.optim.step()
res.update(
{
"loss/icm": loss.item(),
"loss/icm/forward": forward_loss.item(),
"loss/icm/inverse": inverse_loss.item(),
},
)
return res