Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. In
fact, the resulting default behavior should cover the vast majority of
cases.
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None`

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was
non-optional, as there was a ValueError in BasePolicy.init. So now
several examples were fixed to reflect that
2. `reward_normalization` was removed from DDPG and its children. Passing
it as `True` was never allowed there; an error would have been raised in
`compute_n_step_reward`. It is now removed from the interface.
3. renamed `critic1` and similar to `critic`, in order to have uniform
interfaces. Note that the `critic` in DDPG was optional for the sole
reason that child classes used `critic1`. I removed this optionality
(DDPG can't do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)

## Additional changes: 
1. Removed type and default declaration from docstring. This kind of
duplication is really not necessary
2. Policy constructors are now only called using named arguments, not a
fragile mixture of positional and named as before
3. Minor beautifications in typing and code
4. Generally shortened docstrings and made them uniform across all
policies (hopefully)

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00

191 lines
7.9 KiB
Python

from collections.abc import Callable
from typing import Any, Literal
import gymnasium as gym
import numpy as np
import torch
from torch import nn
from tianshou.data import ReplayBuffer, to_torch_as
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.utils.net.common import ActorCritic
class PPOPolicy(A2CPolicy):
    r"""Proximal Policy Optimization (PPO) policy, see arXiv:1707.06347.

    On top of the arguments forwarded unchanged to :class:`A2CPolicy`, this class
    adds the PPO-specific clipping options (``eps_clip``, ``dual_clip``,
    ``value_clip``) and two switches controlling how advantages are handled
    (``advantage_normalization``, ``recompute_advantage``).

    :param actor: the actor network following the rules in BasePolicy. (s -> logits)
    :param critic: the critic network. (s -> V(s))
    :param optim: the optimizer for actor and critic network.
    :param dist_fn: distribution class for computing the action.
    :param action_space: env's action space
    :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
        paper. The same value bounds the value-function update when
        ``value_clip`` is enabled.
    :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
        where c > 1 is a constant indicating the lower bound. Set to None
        to disable dual-clip PPO.
    :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
    :param advantage_normalization: whether to do per mini-batch advantage
        normalization.
    :param recompute_advantage: whether to recompute advantage every update
        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
    :param vf_coef: weight for value loss.
    :param ent_coef: weight for entropy loss.
    :param max_grad_norm: clipping gradients in back propagation.
    :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
    :param max_batchsize: the maximum size of the batch when computing GAE.
    :param discount_factor: in [0, 1].
    :param reward_normalization: normalize estimated values to have std close to 1.
    :param deterministic_eval: if True, use deterministic evaluation.
    :param observation_space: the space of the observation.
    :param action_scaling: if True, scale the action from [-1, 1] to the range of
        action_space. Only used if the action_space is continuous.
    :param action_bound_method: method to bound action to range [-1, 1].
    :param lr_scheduler: if not None, will be called in `policy.update()`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
        action_space: gym.Space,
        eps_clip: float = 0.2,
        dual_clip: float | None = None,
        value_clip: bool = False,
        advantage_normalization: bool = True,
        recompute_advantage: bool = False,
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: float | None = None,
        gae_lambda: float = 0.95,
        max_batchsize: int = 256,
        discount_factor: float = 0.99,
        # TODO: rename to return_normalization?
        reward_normalization: bool = False,
        deterministic_eval: bool = False,
        observation_space: gym.Space | None = None,
        action_scaling: bool = True,
        action_bound_method: Literal["clip", "tanh"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        # Reject an invalid dual-clip constant before the parent constructor runs.
        assert dual_clip is None or dual_clip > 1.0, (
            f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}"
        )

        # Everything that is not PPO-specific is handed to the parent class as-is.
        super().__init__(
            actor=actor,
            critic=critic,
            optim=optim,
            dist_fn=dist_fn,
            action_space=action_space,
            vf_coef=vf_coef,
            ent_coef=ent_coef,
            max_grad_norm=max_grad_norm,
            gae_lambda=gae_lambda,
            max_batchsize=max_batchsize,
            discount_factor=discount_factor,
            reward_normalization=reward_normalization,
            deterministic_eval=deterministic_eval,
            observation_space=observation_space,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            lr_scheduler=lr_scheduler,
        )

        # PPO-specific settings, read by `process_fn` and `learn`.
        self.norm_adv = advantage_normalization
        self.recompute_adv = recompute_advantage
        self.eps_clip = eps_clip
        self.dual_clip = dual_clip
        self.value_clip = value_clip
        # Declaration only; no value is assigned in this class.
        self._actor_critic: ActorCritic

    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> LogpOldProtocol:
        """Add returns and old action log-probabilities to the batch.

        :param batch: the rollout batch to pre-process.
        :param buffer: the buffer passed on to ``_compute_returns``.
        :param indices: the indices passed on to ``_compute_returns``.
        :return: the batch returned by ``_compute_returns``, with ``act`` converted
            via ``to_torch_as(batch.act, batch.v_s)`` and ``logp_old`` added.
        """
        if self.recompute_adv:
            # Keep `buffer` and `indices` around so that `learn()` can recompute
            # the returns on later repeats.
            self._buffer = buffer
            self._indices = indices

        batch = self._compute_returns(batch, buffer, indices)
        batch.act = to_torch_as(batch.act, batch.v_s)

        # Log-probabilities of the taken actions under the current policy; computed
        # without gradient tracking since they only serve as the reference in `learn()`.
        with torch.no_grad():
            current_dist = self(batch).dist
            batch.logp_old = current_dist.log_prob(batch.act)

        batch: LogpOldProtocol
        return batch

    # TODO: why does mypy complain?
    def learn(  # type: ignore
        self,
        batch: RolloutBatchProtocol,
        batch_size: int,
        repeat: int,
        *args: Any,
        **kwargs: Any,
    ) -> dict[str, list[float]]:
        """Run ``repeat`` passes of minibatch PPO updates over ``batch``.

        :param batch: the pre-processed batch; the fields ``obs``, ``act``, ``adv``,
            ``returns``, ``v_s`` and ``logp_old`` are read from its minibatches.
        :param batch_size: minibatch size passed to ``batch.split``.
        :param repeat: number of passes over the whole batch.
        :return: per-minibatch loss values under the keys ``"loss"``,
            ``"loss/clip"``, ``"loss/vf"`` and ``"loss/ent"``.
        """
        history: dict[str, list[float]] = {
            "loss": [],
            "loss/clip": [],
            "loss/vf": [],
            "loss/ent": [],
        }
        eps_clip = self.eps_clip

        for step in range(repeat):
            # From the second pass on, optionally refresh returns using the buffer
            # and indices stored by `process_fn`.
            if self.recompute_adv and step > 0:
                batch = self._compute_returns(batch, self._buffer, self._indices)

            for minibatch in batch.split(batch_size, merge_last=True):
                # ---- actor (clipped surrogate) loss ----
                dist = self(minibatch).dist

                if self.norm_adv:
                    # per-minibatch normalization
                    # NOTE(review): `self._eps` is not set in this class; presumably
                    # it is provided by a parent class.
                    adv_mean = minibatch.adv.mean()
                    adv_std = minibatch.adv.std()
                    minibatch.adv = (minibatch.adv - adv_mean) / (adv_std + self._eps)
                advantages = minibatch.adv

                log_ratio = dist.log_prob(minibatch.act) - minibatch.logp_old
                ratio = log_ratio.exp().float()
                # Flatten everything but the leading dimension, then move that
                # dimension last before multiplying with the advantages.
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)

                unclipped_obj = ratio * advantages
                clipped_obj = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * advantages
                pessimistic_obj = torch.min(unclipped_obj, clipped_obj)

                if self.dual_clip:
                    # Dual-clip: for negative advantages the objective is bounded
                    # from below by `dual_clip * advantage`.
                    lower_bounded = torch.max(pessimistic_obj, self.dual_clip * advantages)
                    clip_loss = -torch.where(
                        advantages < 0,
                        lower_bounded,
                        pessimistic_obj,
                    ).mean()
                else:
                    clip_loss = -pessimistic_obj.mean()

                # ---- critic (value) loss ----
                value = self.critic(minibatch.obs).flatten()
                squared_err = (minibatch.returns - value).pow(2)
                if self.value_clip:
                    # Limit the new value estimate to within `eps_clip` of the old one
                    # and take the larger of the clipped/unclipped squared errors.
                    value_delta = (value - minibatch.v_s).clamp(-eps_clip, eps_clip)
                    clipped_value = minibatch.v_s + value_delta
                    squared_err_clipped = (minibatch.returns - clipped_value).pow(2)
                    vf_loss = torch.max(squared_err, squared_err_clipped).mean()
                else:
                    vf_loss = squared_err.mean()

                # ---- entropy term and total loss ----
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss

                # ---- optimization step ----
                self.optim.zero_grad()
                loss.backward()
                if self.max_grad_norm:  # falsy value (None or 0) skips gradient clipping
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(),
                        max_norm=self.max_grad_norm,
                    )
                self.optim.step()

                history["loss/clip"].append(clip_loss.item())
                history["loss/vf"].append(vf_loss.item())
                history["loss/ent"].append(ent_loss.item())
                history["loss"].append(loss.item())

        return history