Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
88 lines
3.6 KiB
Python
88 lines
3.6 KiB
Python
import torch
|
|
import numpy as np
|
|
from typing import Any, List, Tuple, Union, Optional
|
|
|
|
from tianshou.data import Batch, SegmentTree, to_numpy, ReplayBuffer
|
|
|
|
|
|
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer with Prioritized Experience Replay. arXiv:1511.05952.

    Per-transition priorities are kept in a ``SegmentTree`` stored on the
    instance as ``self.weight`` (not inside ``self._meta``).

    :param int size: forwarded to :class:`~tianshou.data.ReplayBuffer` and also
        used as the size of the segment tree.
    :param float alpha: the prioritization exponent, asserted to be > 0.
    :param float beta: the importance sample soft coefficient, asserted to be
        >= 0.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None:
        # The base initializer is called explicitly on purpose: per the original
        # author's note, ``super().__init__(size, **kwargs)`` raises KeyError in
        # PrioritizedVectorReplayBuffer.
        ReplayBuffer.__init__(self, size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha = alpha
        self._beta = beta
        # running extremes of the (un-exponentiated) priorities seen so far
        self._max_prio = 1.0
        self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        # small positive offset added to every priority in update_weight
        self.__eps = np.finfo(np.float32).eps.item()
        self.options.update(alpha=alpha, beta=beta)

    def init_weight(self, index: Union[int, np.ndarray]) -> None:
        """Set the stored weight at ``index`` to ``max_prio ** alpha``.

        :param index: position(s) in ``self.weight`` to initialize.
        """
        initial_weight = self._max_prio ** self._alpha
        self.weight[index] = initial_weight

    def update(self, buffer: ReplayBuffer) -> np.ndarray:
        """Run the parent ``update`` and initialize weights of the result.

        :param ReplayBuffer buffer: passed through to the parent ``update``.

        :return: whatever the parent ``update`` returned (used here as the
            indices whose weights are initialized).
        """
        written = super().update(buffer)
        self.init_weight(written)
        return written

    def add(
        self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Run the parent ``add`` and initialize the weight of the new slot(s).

        :param Batch batch: passed through to the parent ``add``.
        :param buffer_ids: passed through to the parent ``add``.

        :return: the four values produced by the parent ``add``, in the same
            order; the first one is used as the index for ``init_weight``.
        """
        position, episode_reward, episode_length, episode_start = super().add(
            batch, buffer_ids
        )
        self.init_weight(position)
        return position, episode_reward, episode_length, episode_start

    def sample_index(self, batch_size: int) -> np.ndarray:
        """Return sampled indices, using the stored weights when possible.

        When ``batch_size`` is positive and the buffer is non-empty, uniform
        random numbers are scaled by ``self.weight.reduce()`` and mapped to
        indices through ``self.weight.get_prefix_sum_idx``. Otherwise the call
        is delegated to the parent ``sample_index``.

        :param int batch_size: number of indices requested.
        """
        if not (batch_size > 0 and len(self) > 0):
            return super().sample_index(batch_size)
        uniform_draws = np.random.rand(batch_size)
        targets = uniform_draws * self.weight.reduce()
        return self.weight.get_prefix_sum_idx(targets)  # type: ignore

    def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
        """Get the importance sampling weight.

        The "weight" in the returned Batch is the weight on loss function to de-bias
        the sampling process (some transition tuples are sampled more often so their
        losses are weighted less).

        :param index: position(s) in ``self.weight`` to read.
        """
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # NOTE(review): self.weight holds values already raised to alpha (see
        # init_weight / update_weight) while self._min_prio tracks the
        # un-exponentiated minimum -- confirm this ratio is intended.
        ratio = self.weight[index] / self._min_prio
        return ratio ** (-self._beta)

    def update_weight(
        self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]
    ) -> None:
        """Update priority weight by index in this buffer.

        The absolute value of ``new_weight`` plus a small epsilon is raised to
        ``alpha`` and written into ``self.weight``; the running max/min
        priorities are refreshed from the un-exponentiated values.

        :param np.ndarray index: index you want to update weight.
        :param new_weight: new priority weight you want to update
            (np.ndarray or torch.Tensor, converted via ``to_numpy``).
        """
        priority = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[index] = priority ** self._alpha
        self._max_prio = max(self._max_prio, priority.max())
        self._min_prio = min(self._min_prio, priority.min())

    def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
        """Return the parent's item for ``index`` with a ``weight`` field added.

        Slices are first turned into explicit indices; the importance sampling
        weight from :meth:`get_weight` is attached as ``batch.weight``.

        :param index: slice, int, list of ints or np.ndarray selecting data.
        """
        if not isinstance(index, slice):
            resolved = index
        elif index == slice(None):
            # buffer[:] will get all available data
            resolved = self.sample_index(0)
        else:
            # change slice to np array: apply it to the first len(self) entries
            resolved = self._indices[:len(self)][index]
        item = super().__getitem__(resolved)
        item.weight = self.get_weight(resolved)
        return item
|