bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer (#120)

* bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer

* point out that ListReplayBuffer cannot be sampled

* remove useless _amortization_counter variable
This commit is contained in:
youkaichao 2020-07-10 08:24:11 +08:00 committed by GitHub
parent e767de044b
commit ff99662fe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 40 deletions

View File

@ -92,20 +92,16 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
obs = obs_next
assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
1, rtol=1e-12)
data, indice = buf.sample(len(buf) // 2)
if len(buf) // 2 == 0:
assert len(data) == len(buf)
else:
assert len(data) == len(buf) // 2
assert len(buf) == min(bufsize, i + 1)
assert np.isclose(buf._weight_sum, (buf.weight).sum())
data, indice = buf.sample(len(buf) // 2)
buf.update_weight(indice, -data.weight / 2)
assert np.isclose(buf.weight[indice], np.power(
np.abs(-data.weight / 2), buf._alpha)).all()
assert np.isclose(buf._weight_sum, (buf.weight).sum())
assert np.allclose(
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
if __name__ == '__main__':

View File

@ -151,6 +151,8 @@ class ReplayBuffer:
def update(self, buffer: 'ReplayBuffer') -> None:
"""Move the data from the given buffer to self."""
if len(buffer) == 0:
return
i = begin = buffer._index % len(buffer)
while True:
self.add(**buffer[i])
@ -298,7 +300,9 @@ class ReplayBuffer:
class ListReplayBuffer(ReplayBuffer):
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore,
it does not support advanced indexing, which means you cannot sample a
batch of data out of it. It is typically used for storing data.
.. seealso::
@ -309,6 +313,9 @@ class ListReplayBuffer(ReplayBuffer):
def __init__(self, **kwargs) -> None:
    """Construct a list-backed buffer.

    ``size=0`` makes the parent use growable ``list`` storage instead of a
    pre-allocated array; ``ignore_obs_next=False`` keeps next-observations.

    :param kwargs: forwarded unchanged to ``ReplayBuffer.__init__``.
    """
    super().__init__(size=0, ignore_obs_next=False, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
    """Always raise ``NotImplementedError``.

    The list-based storage does not support the advanced indexing needed to
    gather a batch, so this buffer cannot be sampled.
    """
    raise NotImplementedError("ListReplayBuffer cannot be sampled!")
def _add_to_buffer(
self, name: str,
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
@ -349,7 +356,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._beta = beta
self._weight_sum = 0.0
self._amortization_freq = 50
self._amortization_counter = 0
self._replace = replace
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
@ -369,7 +375,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._meta.__dict__['weight'][self._index]
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
super().add(obs, act, rew, done, obs_next, info, policy)
self._check_weight_sum()
@property
def replace(self):
@ -379,46 +384,38 @@ class PrioritizedReplayBuffer(ReplayBuffer):
def replace(self, v: bool):
    # Setter for the ``replace`` property: controls whether
    # ``np.random.choice`` may draw the same index more than once
    # (passed as its ``replace=`` argument in ``sample``).
    self._replace = v
def sample(self, batch_size: int,
importance_sample: bool = True
) -> Tuple[Batch, np.ndarray]:
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability. \
Return all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
"""
if batch_size > 0 and batch_size <= self._size:
# Multiple sampling of the same sample
# will cause weight update conflict
assert self._size > 0, 'cannot sample a buffer with size == 0 !'
p = None
if batch_size > 0 and (self._replace or batch_size <= self._size):
# sampling weight
p = (self.weight / self.weight.sum())[:self._size]
indice = np.random.choice(
self._size, batch_size,
p=(self.weight / self.weight.sum())[:self._size],
self._size, batch_size, p=p,
replace=self._replace)
# self._weight_sum is not work for the accuracy issue
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
p = p[indice] # weight of each sample
elif batch_size == 0:
p = np.full(shape=self._size, fill_value=1.0/self._size)
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
else:
# if batch_size larger than len(self),
# it will lead to a bug in update weight
raise ValueError(
"batch_size should be less than len(self), \
or set replace=False")
f"batch_size should be less than {len(self)}, \
or set replace=True")
batch = self[indice]
if importance_sample:
impt_weight = Batch(
impt_weight=1 / np.power(
self._size * (batch.weight / self._weight_sum),
self._beta))
batch.cat_(impt_weight)
self._check_weight_sum()
impt_weight = Batch(
impt_weight=(self._size * p) ** (-self._beta))
batch.cat_(impt_weight)
return batch, indice
def reset(self) -> None:
self._amortization_counter = 0
super().reset()
def update_weight(self, indice: Union[slice, np.ndarray],
@ -436,8 +433,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
indice, unique_indice = np.unique(
indice, return_index=True)
new_weight = new_weight[unique_indice]
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
- self.weight[indice].sum()
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
def __getitem__(self, index: Union[
@ -452,10 +447,3 @@ class PrioritizedReplayBuffer(ReplayBuffer):
weight=self.weight[index],
policy=self.get(index, 'policy'),
)
def _check_weight_sum(self) -> None:
# keep an accurate _weight_sum
self._amortization_counter += 1
if self._amortization_counter % self._amortization_freq == 0:
self._weight_sum = np.sum(self.weight)
self._amortization_counter = 0