From ff99662fe6f99b31e62e99656a05cda211e08e90 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 Jul 2020 08:24:11 +0800 Subject: [PATCH] bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer (#120) * bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer * point out that ListReplayBuffer cannot be sampled * remove useless _amortization_counter variable --- test/base/test_buffer.py | 8 ++---- tianshou/data/buffer.py | 56 ++++++++++++++++------------------------ 2 files changed, 24 insertions(+), 40 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index b4eac93..28ccd88 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -92,20 +92,16 @@ def test_priortized_replaybuffer(size=32, bufsize=15): obs_next, rew, done, info = env.step(a) buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5) obs = obs_next - assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]), - 1, rtol=1e-12) data, indice = buf.sample(len(buf) // 2) if len(buf) // 2 == 0: assert len(data) == len(buf) else: assert len(data) == len(buf) // 2 assert len(buf) == min(bufsize, i + 1) - assert np.isclose(buf._weight_sum, (buf.weight).sum()) data, indice = buf.sample(len(buf) // 2) buf.update_weight(indice, -data.weight / 2) - assert np.isclose(buf.weight[indice], np.power( - np.abs(-data.weight / 2), buf._alpha)).all() - assert np.isclose(buf._weight_sum, (buf.weight).sum()) + assert np.allclose( + buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) if __name__ == '__main__': diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 5c0bcad..33d3178 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -151,6 +151,8 @@ class ReplayBuffer: def update(self, buffer: 'ReplayBuffer') -> None: """Move the data from the given buffer to self.""" + if len(buffer) == 0: + return i = begin = buffer._index % len(buffer) while True: self.add(**buffer[i]) @@ -298,7 +300,9 @@ class ReplayBuffer: class ListReplayBuffer(ReplayBuffer): """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that - :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. + :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore, + it does not support advanced indexing, which means you cannot sample a + batch of data out of it. It is typically used for storing data. .. seealso:: @@ -309,6 +313,9 @@ class ListReplayBuffer(ReplayBuffer): def __init__(self, **kwargs) -> None: super().__init__(size=0, ignore_obs_next=False, **kwargs) + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: + raise NotImplementedError("ListReplayBuffer cannot be sampled!") + def _add_to_buffer( self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: @@ -349,7 +356,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._beta = beta self._weight_sum = 0.0 self._amortization_freq = 50 - self._amortization_counter = 0 self._replace = replace self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64) @@ -369,7 +375,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._meta.__dict__['weight'][self._index] self._add_to_buffer('weight', np.abs(weight) ** self._alpha) super().add(obs, act, rew, done, obs_next, info, policy) - self._check_weight_sum() @property def replace(self): @@ -379,46 +384,38 @@ class PrioritizedReplayBuffer(ReplayBuffer): def replace(self, v: bool): self._replace = v - def sample(self, batch_size: int, - importance_sample: bool = True - ) -> Tuple[Batch, np.ndarray]: + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with priority probability. \ Return all the data in the buffer if batch_size is ``0``. :return: Sample data and its corresponding index inside the buffer. """ - if batch_size > 0 and batch_size <= self._size: - # Multiple sampling of the same sample - # will cause weight update conflict + assert self._size > 0, 'cannot sample a buffer with size == 0 !' + p = None + if batch_size > 0 and (self._replace or batch_size <= self._size): + # sampling weight + p = (self.weight / self.weight.sum())[:self._size] indice = np.random.choice( - self._size, batch_size, - p=(self.weight / self.weight.sum())[:self._size], + self._size, batch_size, p=p, replace=self._replace) - # self._weight_sum is not work for the accuracy issue - # p=(self.weight/self._weight_sum)[:self._size], replace=False) + p = p[indice] # weight of each sample elif batch_size == 0: + p = np.full(shape=self._size, fill_value=1.0/self._size) indice = np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index), ]) else: - # if batch_size larger than len(self), - # it will lead to a bug in update weight raise ValueError( - "batch_size should be less than len(self), \ - or set replace=False") + f"batch_size should be less than {len(self)}, \ + or set replace=True") batch = self[indice] - if importance_sample: - impt_weight = Batch( - impt_weight=1 / np.power( - self._size * (batch.weight / self._weight_sum), - self._beta)) - batch.cat_(impt_weight) - self._check_weight_sum() + impt_weight = Batch( + impt_weight=(self._size * p) ** (-self._beta)) + batch.cat_(impt_weight) return batch, indice def reset(self) -> None: - self._amortization_counter = 0 super().reset() def update_weight(self, indice: Union[slice, np.ndarray], @@ -436,8 +433,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): indice, unique_indice = np.unique( indice, return_index=True) new_weight = new_weight[unique_indice] - self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \ - - self.weight[indice].sum() self.weight[indice] = np.power(np.abs(new_weight), self._alpha) def __getitem__(self, index: Union[ @@ -452,10 +447,3 @@ class PrioritizedReplayBuffer(ReplayBuffer): weight=self.weight[index], policy=self.get(index, 'policy'), ) - - def _check_weight_sum(self) -> None: - # keep an accurate _weight_sum - self._amortization_counter += 1 - if self._amortization_counter % self._amortization_freq == 0: - self._weight_sum = np.sum(self.weight) - self._amortization_counter = 0