Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`.
In fact, the resulting default behavior should cover the vast majority of
cases.
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None`

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already
effectively required, since `BasePolicy.__init__` raised a ValueError
without it. Several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. Passing
it as `True` there was never allowed; an error would have been raised in
`compute_nstep_return`. It has now been removed from the interface.
3. renamed `critic1` and similar to `critic`, in order to have uniform
interfaces. Note that the `critic` in DDPG was optional for the sole
reason that child classes used `critic1`. I removed this optionality
(DDPG can't do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)

## Additional changes: 
1. Removed type and default declarations from docstrings; this kind of
duplication is unnecessary.
2. Policy constructors are now only called using named arguments, not a
fragile mixture of positional and named arguments as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all
policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00

235 lines
8.3 KiB
Python

from copy import deepcopy
from typing import Any, Self, cast
import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
BatchWithReturnsProtocol,
ModelOutputBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
class DQNPolicy(BasePolicy):
    """Implementation of Deep Q Network. arXiv:1312.5602.

    Implementation of Double Q-Learning. arXiv:1509.06461.

    Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
    implemented in the network side, not here).

    :param model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param optim: a torch.optim for optimizing the model.
    :param action_space: Env's action space (annotated as ``gym.spaces.Discrete``).
    :param discount_factor: in [0, 1].
    :param estimation_step: the number of steps to look ahead (must be > 0).
    :param target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param reward_normalization: normalize the **returns** to Normal(0, 1).
        TODO: rename to return_normalization?
    :param is_double: use double dqn.
    :param clip_loss_grad: clip the gradient of the loss in accordance
        with nature14236; this amounts to using the Huber loss instead of
        the MSE loss.
    :param observation_space: Env's observation space.
    :param lr_scheduler: if not None, will be called in `policy.update()`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        # TODO: type violates Liskov substitution principle
        action_space: gym.spaces.Discrete,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        is_double: bool = True,
        clip_loss_grad: bool = False,
        observation_space: gym.Space | None = None,
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=False,
            action_bound_method=None,
            lr_scheduler=lr_scheduler,
        )
        self.model = model
        self.optim = optim
        # Epsilon for epsilon-greedy exploration; 0 disables exploration noise.
        self.eps = 0.0
        assert (
            0.0 <= discount_factor <= 1.0
        ), f"discount factor should be in [0, 1] but got: {discount_factor}"
        self.gamma = discount_factor
        assert (
            estimation_step > 0
        ), f"estimation_step should be greater than 0 but got: {estimation_step}"
        self.n_step = estimation_step
        # A target network is only maintained when a positive update frequency is given.
        self._target = target_update_freq > 0
        self.freq = target_update_freq
        # Number of `learn` calls so far; drives periodic target-network syncing.
        self._iter = 0
        if self._target:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
        self.rew_norm = reward_normalization
        self.is_double = is_double
        self.clip_loss_grad = clip_loss_grad
        # Number of discrete actions. Only annotated here; it is assigned lazily in
        # `forward` from the width of the Q-value output and read by `exploration_noise`.
        # TODO: set in forward, fix this!
        self.max_action_num: int

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self, mode: bool = True) -> Self:
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.model.train(mode)
        return self

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
        self.model_old.load_state_dict(self.model.state_dict())

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        """Estimate the bootstrap Q-value of ``obs_next`` for the given buffer indices.

        With ``is_double`` the greedy action is selected by the online model and
        evaluated by the target model (or the online model if no target network is
        used); otherwise the max over the evaluating model's logits is taken.
        """
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        result = self(batch, input="obs_next")
        if self._target:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            target_q = self(batch, model="model_old", input="obs_next").logits
        else:
            target_q = result.logits
        if self.is_double:
            return target_q[np.arange(len(result.act)), result.act]
        # Nature DQN, over estimate
        return target_q.max(dim=1)[0]

    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> BatchWithReturnsProtocol:
        """Compute the n-step return for Q-learning targets.

        More details can be found at
        :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
        """
        return self.compute_nstep_return(
            batch=batch,
            buffer=buffer,
            indices=indices,
            target_q_fn=self._target_q,
            gamma=self.gamma,
            n_step=self.n_step,
            rew_norm=self.rew_norm,
        )

    def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor:
        """Compute the q value based on the network's raw output and action mask."""
        if mask is not None:
            # the masked q value should be smaller than logits.min():
            # any logit + (min - max - 1) <= min - 1 < min.
            min_value = logits.min() - logits.max() - 1.0
            logits = logits + to_torch_as(1 - mask, logits) * min_value
        return logits

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Compute action over the given batch data.

        If you need to mask the action, please add a "mask" into batch.obs, for
        example, if we have an environment that has "0/1/2" three actions:
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        :param model: name of the attribute holding the network to evaluate
            (e.g. ``"model"`` or ``"model_old"``).
        :param input: name of the batch key to feed to the network
            (e.g. ``"obs"`` or ``"obs_next"``).

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        # A masked observation wraps the real observation under `.obs`.
        model_input = obs.obs if hasattr(obs, "obs") else obs
        logits, hidden = model(model_input, state=state, info=batch.info)
        q = self.compute_q_value(logits, getattr(obs, "mask", None))
        if not hasattr(self, "max_action_num"):
            self.max_action_num = q.shape[1]
        act = to_numpy(q.max(dim=1)[1])
        result = Batch(logits=logits, act=act, state=hidden)
        return cast(ModelOutputBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        """Perform one gradient step regressing Q(s, a) towards ``batch.returns``.

        Syncs the target network every ``target_update_freq`` calls (if enabled) and
        stores the TD error in ``batch.weight`` for use by a prioritized buffer.

        :return: a dict holding the scalar training loss under the key ``"loss"``.
        """
        if self._target and self._iter % self.freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        # Per-sample loss weights (presumably importance-sampling weights from a
        # prioritized buffer); defaults to 1.0 when absent.
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        returns = to_torch_as(batch.returns.flatten(), q)
        td_error = returns - q
        if self.clip_loss_grad:
            # Compute the Huber loss per sample so the weights are applied exactly as in
            # the MSE branch; with weight == 1.0 this equals reduction="mean".
            huber = torch.nn.functional.huber_loss(q, returns, reduction="none")
            loss = (huber * weight).mean()
        else:
            loss = (td_error.pow(2) * weight).mean()
        batch.weight = td_error  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}

    def exploration_noise(
        self,
        act: np.ndarray | BatchProtocol,
        batch: RolloutBatchProtocol,
    ) -> np.ndarray | BatchProtocol:
        """Apply epsilon-greedy exploration to ``act`` (in place) and return it.

        Each entry is replaced with probability ``eps`` by an action drawn uniformly
        among the available ones (all actions, or those allowed by ``batch.obs.mask``).
        Only numpy actions are perturbed, and nothing happens when ``eps`` is 0.
        Relies on ``max_action_num``, which is set by a prior call to ``forward``.
        """
        if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
            bsz = len(act)
            rand_mask = np.random.rand(bsz) < self.eps
            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
            if hasattr(batch.obs, "mask"):
                # Available actions get +1, so argmax always picks an available one.
                q += batch.obs.mask
            rand_act = q.argmax(axis=1)
            act[rand_mask] = rand_act[rand_mask]
        return act