n+e 454c86c469
fix venv seed, add TOC in docs, and split buffer.py into several files (#303)
Things changed in this PR:

- various docs update, add TOC
- split buffer into several files
- fix venv action_space randomness
2021-03-02 12:28:28 +08:00

83 lines
3.3 KiB
Python

import torch
import numpy as np
from typing import Any, List, Tuple, Union, Optional
from tianshou.data import Batch, SegmentTree, to_numpy, ReplayBuffer
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    :param int size: the capacity of the buffer; also used as the size of the
        segment tree that stores the per-slot priorities.
    :param float alpha: the prioritization exponent (must be > 0).
    :param float beta: the importance sample soft coefficient (must be >= 0).

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None:
        # Call the base initializer explicitly instead of through super():
        # super().__init__(size, **kwargs) will raise KeyError in
        # PrioritizedVectorReplayBuffer.
        ReplayBuffer.__init__(self, size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Running bounds of the raw (un-exponentiated) priorities seen so far.
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        # Small positive constant added to |priority| in update_weight().
        self.__eps = np.finfo(np.float32).eps.item()
        self.options.update(alpha=alpha, beta=beta)

    def init_weight(self, index: Union[int, np.ndarray]) -> None:
        """Assign the current maximum priority (raised to alpha) to ``index``.

        :param index: the slot(s) whose priority should be (re)initialized.
        """
        self.weight[index] = self._max_prio ** self._alpha

    def update(self, buffer: ReplayBuffer) -> np.ndarray:
        """Move the data from the given buffer into this one.

        The transfer itself is delegated to :meth:`ReplayBuffer.update`; the
        slots it reports are then given the initial (maximum) priority.

        :param ReplayBuffer buffer: the buffer to take the data from.

        :return: the indices returned by the parent ``update`` call.
        """
        indices = super().update(buffer)
        self.init_weight(indices)
        # The original implementation dropped this value even though the
        # signature promises an ndarray, so callers always received None.
        return indices

    def add(
        self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data and give the written slots the initial priority.

        :param Batch batch: the data to store, forwarded to the parent ``add``.
        :param buffer_ids: forwarded unchanged to the parent ``add``.

        :return: the ``(ptr, ep_rew, ep_len, ep_idx)`` tuple produced by the
            parent ``add``.
        """
        ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids)
        self.init_weight(ptr)
        return ptr, ep_rew, ep_len, ep_idx

    def sample_index(self, batch_size: int) -> np.ndarray:
        """Sample ``batch_size`` indices according to the stored priorities.

        When ``batch_size`` is not positive or the buffer is empty, the call
        falls back to the parent implementation.

        :param int batch_size: the number of indices to draw.

        :return: the sampled indices.
        """
        if batch_size > 0 and len(self) > 0:
            # Scale uniform draws from [0, 1) by the tree's reduction
            # (presumably the sum of all stored priorities -- confirm against
            # SegmentTree) and map each one to an index via prefix-sum lookup.
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            return self.weight.get_prefix_sum_idx(scalar)
        else:
            return super().sample_index(batch_size)

    def get_weight(
        self, index: Union[slice, int, np.integer, np.ndarray]
    ) -> np.ndarray:
        """Get the importance sampling weight.

        The "weight" in the returned Batch is the weight on loss function to de-bias
        the sampling process (some transition tuples are sampled more often so their
        losses are weighted less).

        :param index: the slot(s) to compute the weight for.

        :return: the importance sampling weight(s) for ``index``.
        """
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # NOTE(review): self.weight stores priorities raised to alpha, whereas
        # self._min_prio tracks the raw minimum priority -- confirm that mixing
        # the two scales here is intended.
        return (self.weight[index] / self._min_prio) ** (-self._beta)

    def update_weight(
        self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]
    ) -> None:
        """Update priority weight by index in this buffer.

        :param np.ndarray index: index you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
        """
        # Take the magnitude and add eps so that every stored priority is
        # strictly positive.
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[index] = weight ** self._alpha
        self._max_prio = max(self._max_prio, weight.max())
        self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch:
        """Return the parent's batch for ``index`` with a ``weight`` field added.

        :param index: the slot(s) to fetch.

        :return: the batch from the parent ``__getitem__``, with ``weight`` set
            to :meth:`get_weight` of the same index.
        """
        batch = super().__getitem__(index)
        batch.weight = self.get_weight(index)
        return batch