Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. In
fact, the resulting default behavior should cover the vast majority of cases
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None` (see the sketch after this list)
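A rough sketch of the idea, assuming the helper simply rebuilds the optimizer class from its recorded defaults; the actual implementation in tianshou may differ:

```python
import copy

import torch


def clone_optimizer(
    optim: torch.optim.Optimizer,
    new_params: list[torch.nn.Parameter],
) -> torch.optim.Optimizer:
    # Rebuild the same optimizer class with identical hyperparameters,
    # but bound to a different set of parameters.
    return type(optim)(new_params, **optim.defaults)


# Sketch of the defaulting behavior: if critic2/critic2_optim are omitted,
# derive them from critic/critic_optim.
critic = torch.nn.Linear(4, 1)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)
critic2 = copy.deepcopy(critic)
critic2_optim = clone_optimizer(critic_optim, list(critic2.parameters()))
```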

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already effectively
non-optional, since `BasePolicy.__init__` raised a `ValueError` without it.
Several examples were fixed to reflect that (illustrated after this list)
2. `reward_normalization` was removed from DDPG and its children. Passing
`True` there was never allowed and would have raised an error in
`compute_n_step_reward`, so it was removed from the interface
3. Renamed `critic1` and similar to `critic`, in order to have uniform
interfaces. Note that the `critic` in DDPG was optional for the sole
reason that child classes used `critic1`; this optionality was removed
(DDPG can't do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)
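A toy mirror of breaking change 1 (not tianshou's actual `BasePolicy` code): with no default for `action_space`, forgetting it now fails loudly at the call site.

```python
from typing import Any


class ToyBasePolicy:
    # The bare ``*`` makes all arguments keyword-only, and action_space
    # has no default, so it must always be passed explicitly.
    def __init__(self, *, action_space: Any) -> None:
        self.action_space = action_space


ToyBasePolicy(action_space="Discrete(4)")  # OK
# ToyBasePolicy()  # TypeError: missing required keyword-only argument
```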

## Additional changes:
1. Removed type and default declarations from docstrings. This kind of
duplication is really not necessary (see the before/after sketch following
this list)
2. Policy constructors are now called using named arguments only, not a
fragile mixture of positional and named arguments as before
3. Minor beautifications in typing and code
4. Generally shortened docstrings and made them uniform across all
policies (hopefully)
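A before/after sketch of the docstring convention; the parameter is hypothetical, not taken from the codebase:

```python
def make_policy_old(tau: float = 0.005) -> None:
    """Old style: type and default duplicated in the docstring.

    :param float tau: the soft-update coefficient. Default to 0.005.
    """


def make_policy_new(tau: float = 0.005) -> None:
    """New style: the signature is the single source of truth.

    :param tau: the soft-update coefficient.
    """
```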

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00


from typing import Any, cast

import numpy as np
import torch

from tianshou.data import ReplayBuffer, SegmentTree, to_numpy
from tianshou.data.types import PrioBatchProtocol, RolloutBatchProtocol


class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    :param alpha: the prioritization exponent.
    :param beta: the importance sampling soft coefficient.
    :param weight_norm: whether to normalize returned weights with the maximum
        weight value within the batch.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(
        self,
        size: int,
        alpha: float,
        beta: float,
        weight_norm: bool = True,
        **kwargs: Any,
    ) -> None:
        # will raise KeyError in PrioritizedVectorReplayBuffer
        # super().__init__(size, **kwargs)
        ReplayBuffer.__init__(self, size, **kwargs)
        assert alpha > 0.0
        assert beta >= 0.0
        self._alpha, self._beta = alpha, beta
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        self.__eps = np.finfo(np.float32).eps.item()
        self.options.update(alpha=alpha, beta=beta)
        self._weight_norm = weight_norm

    def init_weight(self, index: int | np.ndarray) -> None:
        self.weight[index] = self._max_prio**self._alpha

    def update(self, buffer: ReplayBuffer) -> np.ndarray:
        indices = super().update(buffer)
        self.init_weight(indices)
        return indices

    def add(
        self,
        batch: RolloutBatchProtocol,
        buffer_ids: np.ndarray | list[int] | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids)
        self.init_weight(ptr)
        return ptr, ep_rew, ep_len, ep_idx

    def sample_indices(self, batch_size: int) -> np.ndarray:
        if batch_size > 0 and len(self) > 0:
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            return self.weight.get_prefix_sum_idx(scalar)  # type: ignore
        return super().sample_indices(batch_size)

    def get_weight(self, index: int | np.ndarray) -> float | np.ndarray:
        """Get the importance sampling weight.

        The "weight" in the returned Batch is the weight on the loss function used to
        debias the sampling process (some transition tuples are sampled more often, so
        their losses are weighted less).
        """
        # importance sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula (N and p_sum cancel): (p_j/p_min)**(-beta)
        return (self.weight[index] / self._min_prio) ** (-self._beta)

    def update_weight(self, index: np.ndarray, new_weight: np.ndarray | torch.Tensor) -> None:
        """Update priority weight by index in this buffer.

        :param index: the index (or indices) whose priority should be updated.
        :param new_weight: the new priority weight(s), e.g. absolute TD errors.
        """
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[index] = weight**self._alpha
        self._max_prio = max(self._max_prio, weight.max())
        self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> PrioBatchProtocol:
        if isinstance(index, slice):  # change slice to np array
            # buffer[:] will get all available data
            indices = (
                self.sample_indices(0)
                if index == slice(None)
                else self._indices[: len(self)][index]
            )
        else:
            indices = index  # type: ignore
        batch = super().__getitem__(indices)
        weight = self.get_weight(indices)
        # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
        batch.weight = weight / np.max(weight) if self._weight_norm else weight
        return cast(PrioBatchProtocol, batch)

    def set_beta(self, beta: float) -> None:
        self._beta = beta
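A short usage sketch for the class above, assuming tianshou's `Batch` conventions at this commit (each added value carries a leading batch dimension of 1; the exact set of required keys may vary across versions):

```python
import numpy as np

from tianshou.data import Batch, PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(size=8, alpha=0.6, beta=0.4)
for i in range(4):
    buf.add(
        Batch(
            obs=np.array([i]),
            act=np.array([0]),
            rew=np.array([1.0]),
            terminated=np.array([False]),
            truncated=np.array([False]),
            obs_next=np.array([i + 1]),
        ),
    )

# sampling is proportional to stored priority**alpha; batch.weight holds
# the importance sampling weights computed in get_weight above
batch, indices = buf.sample(2)
# after a gradient step, feed back e.g. absolute TD errors as new priorities
buf.update_weight(indices, np.array([0.5, 2.0]))
```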