n+e fc251ab0b8
bump to v0.4.3 (#432)
* add makefile
* bump version
* add isort and yapf
* update contributing.md
* update PR template
* spelling check
2021-09-03 05:05:04 +08:00

106 lines
4.0 KiB
Python

from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, SegmentTree, to_numpy
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.
    :param bool weight_norm: whether to normalize returned weights with the maximum
        weight value within the batch. Default to True.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(
        self,
        size: int,
        alpha: float,
        beta: float,
        weight_norm: bool = True,
        **kwargs: Any
    ) -> None:
        # Call ReplayBuffer.__init__ explicitly: using
        # super().__init__(size, **kwargs) here will raise KeyError in
        # PrioritizedVectorReplayBuffer.
        ReplayBuffer.__init__(self, size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Running extrema of the *raw* priorities (|new_weight| + eps), i.e.
        # NOT raised to alpha. Both start at 1.0 and are never reset.
        self._max_prio = self._min_prio = 1.0
        # Save weight directly in this class instead of self._meta.
        # Each slot holds priority**alpha; sampling is proportional to it.
        self.weight = SegmentTree(size)
        # Added to |new_weight| so no stored priority is exactly zero.
        self.__eps = np.finfo(np.float32).eps.item()
        self.options.update(alpha=alpha, beta=beta)
        self._weight_norm = weight_norm

    def init_weight(self, index: Union[int, np.ndarray]) -> None:
        """Assign the largest raw priority seen so far (raised to alpha) to ``index``.

        Freshly written slots therefore get the highest current sampling weight.
        """
        self.weight[index] = self._max_prio**self._alpha

    def update(self, buffer: ReplayBuffer) -> np.ndarray:
        """Copy ``buffer`` in via ReplayBuffer.update and prioritize the written slots.

        :return: the indices returned by ReplayBuffer.update, i.e. the slots that
            received max priority here.
        """
        indices = super().update(buffer)
        self.init_weight(indices)
        return indices

    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add ``batch`` via ReplayBuffer.add and give the new slots max priority.

        :return: the ``(ptr, ep_rew, ep_len, ep_idx)`` tuple from
            ReplayBuffer.add, unchanged.
        """
        ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids)
        self.init_weight(ptr)
        return ptr, ep_rew, ep_len, ep_idx

    def sample_indices(self, batch_size: int) -> np.ndarray:
        """Sample ``batch_size`` indices with probability proportional to priority**alpha.

        When ``batch_size <= 0`` or the buffer is empty, this defers to
        ReplayBuffer.sample_indices. ``__getitem__`` relies on
        ``sample_indices(0)`` for ``buffer[:]``.
        """
        if batch_size > 0 and len(self) > 0:
            # Uniform points in [0, reduce()) located in the prefix-sum tree.
            # Presumably reduce() is the sum over all slots, so each slot is
            # hit with probability weight / total.
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            return self.weight.get_prefix_sum_idx(scalar)  # type: ignore
        return super().sample_indices(batch_size)

    def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
        """Get the importance sampling weight.

        This is the weight on the loss function used to debias the sampling
        process: some transition tuples are sampled more often, so their losses
        are weighted less. ``__getitem__`` stores it as ``batch.weight``.
        """
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # self.weight holds priority**alpha, whereas self._min_prio is a raw
        # priority. Raise it to alpha so numerator and denominator share units.
        # x**alpha is monotonic for alpha > 0, so min(p**alpha) == min(p)**alpha.
        # Every stored weight is >= _min_prio**alpha, so results lie in (0, 1].
        return (self.weight[index] / self._min_prio**self._alpha)**(-self._beta)

    def update_weight(
        self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]
    ) -> None:
        """Update priority weight by index in this buffer.

        :param np.ndarray index: index you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
            ``abs(new_weight) + eps`` becomes the raw priority; the stored
            sampling weight is that value raised to alpha.
        """
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        if weight.size == 0:
            # Nothing to update; .max()/.min() below would raise on empty input.
            return
        self.weight[index] = weight**self._alpha
        self._max_prio = max(self._max_prio, weight.max())
        self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
        """Return the data at ``index`` plus a ``weight`` field of IS weights.

        If ``weight_norm`` is set, the weights are divided by their maximum
        within the returned selection.
        """
        if isinstance(index, slice):  # change slice to np array
            if index == slice(None):
                # buffer[:] will get all available data
                indices = self.sample_indices(0)
            else:
                indices = self._indices[:len(self)][index]
        else:
            indices = index
        batch = super().__getitem__(indices)
        weight = self.get_weight(indices)
        # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
        # Skip normalization for an empty selection (e.g. buf[3:3]):
        # np.max has no identity for zero-size arrays and would raise.
        if self._weight_norm and np.size(weight) > 0:
            weight = weight / np.max(weight)
        batch.weight = weight
        return batch

    def set_beta(self, beta: float) -> None:
        """Set the importance-sampling exponent used by :meth:`get_weight`."""
        self._beta = beta