diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py index 8450c48..70bfdbb 100644 --- a/test/discrete/test_pdqn.py +++ b/test/discrete/test_pdqn.py @@ -74,7 +74,8 @@ def test_pdqn(args=get_args()): # collector if args.prioritized_replay > 0: buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, beta=args.alpha) + args.buffer_size, alpha=args.alpha, + beta=args.alpha, repeat_sample=True) else: buf = ReplayBuffer(args.buffer_size) train_collector = Collector( diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index cd05393..962dc7e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -340,6 +340,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): :param float alpha: the prioritization exponent. :param float beta: the importance sample soft coefficient. :param str mode: defaults to ``weight``. + :param bool replace: whether to sample with replacement .. seealso:: @@ -348,7 +349,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): """ def __init__(self, size: int, alpha: float, beta: float, - mode: str = 'weight', **kwargs) -> None: + mode: str = 'weight', + replace: bool = False, **kwargs) -> None: if mode != 'weight': raise NotImplementedError super().__init__(size, **kwargs) @@ -358,6 +360,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): self.weight = np.zeros(size, dtype=np.float64) self._amortization_freq = 50 self._amortization_counter = 0 + self._replace = replace def add(self, obs: Union[dict, np.ndarray], @@ -377,6 +380,14 @@ class PrioritizedReplayBuffer(ReplayBuffer): super().add(obs, act, rew, done, obs_next, info, policy) self._check_weight_sum() + @property + def replace(self): + return self._replace + + @replace.setter + def replace(self, v: bool): + self._replace = v + def sample(self, batch_size: int, importance_sample: bool = True ) -> Tuple[Batch, np.ndarray]: @@ -391,7 +402,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): indice = np.random.choice( self._size, batch_size, p=(self.weight / self.weight.sum())[:self._size], - replace=False) + replace=self._replace) # self._weight_sum is not work for the accuracy issue # p=(self.weight/self._weight_sum)[:self._size], replace=False) elif batch_size == 0: @@ -402,7 +413,9 @@ class PrioritizedReplayBuffer(ReplayBuffer): else: # if batch_size larger than len(self), # it will lead to a bug in update weight - raise ValueError("batch_size should be less than len(self)") + raise ValueError( + "batch_size should be less than len(self), \ + or set replace=False") batch = self[indice] if importance_sample: impt_weight = Batch( @@ -424,6 +437,14 @@ class PrioritizedReplayBuffer(ReplayBuffer): :param np.ndarray indice: indice you want to update weight :param np.ndarray new_weight: new priority weight you wangt to update """ + if self._replace: + if isinstance(indice, slice): + # convert slice to ndarray + indice = np.arange(indice.stop)[indice] + # remove the same values in indice + indice, unique_indice = np.unique( + indice, return_index=True) + new_weight = new_weight[unique_indice] self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \ - self.weight[indice].sum() self.weight[indice] = np.power(np.abs(new_weight), self._alpha)