bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer (#120)
* bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer * point out that ListReplayBuffer cannot be sampled * remove useless _amortization_counter variable
This commit is contained in:
parent
e767de044b
commit
ff99662fe6
@ -92,20 +92,16 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||
obs_next, rew, done, info = env.step(a)
|
||||
buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
|
||||
obs = obs_next
|
||||
assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
|
||||
1, rtol=1e-12)
|
||||
data, indice = buf.sample(len(buf) // 2)
|
||||
if len(buf) // 2 == 0:
|
||||
assert len(data) == len(buf)
|
||||
else:
|
||||
assert len(data) == len(buf) // 2
|
||||
assert len(buf) == min(bufsize, i + 1)
|
||||
assert np.isclose(buf._weight_sum, (buf.weight).sum())
|
||||
data, indice = buf.sample(len(buf) // 2)
|
||||
buf.update_weight(indice, -data.weight / 2)
|
||||
assert np.isclose(buf.weight[indice], np.power(
|
||||
np.abs(-data.weight / 2), buf._alpha)).all()
|
||||
assert np.isclose(buf._weight_sum, (buf.weight).sum())
|
||||
assert np.allclose(
|
||||
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -151,6 +151,8 @@ class ReplayBuffer:
|
||||
|
||||
def update(self, buffer: 'ReplayBuffer') -> None:
|
||||
"""Move the data from the given buffer to self."""
|
||||
if len(buffer) == 0:
|
||||
return
|
||||
i = begin = buffer._index % len(buffer)
|
||||
while True:
|
||||
self.add(**buffer[i])
|
||||
@ -298,7 +300,9 @@ class ReplayBuffer:
|
||||
class ListReplayBuffer(ReplayBuffer):
|
||||
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
|
||||
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
|
||||
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
|
||||
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore,
|
||||
it does not support advanced indexing, which means you cannot sample a
|
||||
batch of data out of it. It is typically used for storing data.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -309,6 +313,9 @@ class ListReplayBuffer(ReplayBuffer):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
||||
|
||||
def _add_to_buffer(
|
||||
self, name: str,
|
||||
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
|
||||
@ -349,7 +356,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
self._beta = beta
|
||||
self._weight_sum = 0.0
|
||||
self._amortization_freq = 50
|
||||
self._amortization_counter = 0
|
||||
self._replace = replace
|
||||
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
|
||||
|
||||
@ -369,7 +375,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
self._meta.__dict__['weight'][self._index]
|
||||
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||
self._check_weight_sum()
|
||||
|
||||
@property
|
||||
def replace(self):
|
||||
@ -379,46 +384,38 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
def replace(self, v: bool):
|
||||
self._replace = v
|
||||
|
||||
def sample(self, batch_size: int,
|
||||
importance_sample: bool = True
|
||||
) -> Tuple[Batch, np.ndarray]:
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
"""Get a random sample from buffer with priority probability. \
|
||||
Return all the data in the buffer if batch_size is ``0``.
|
||||
|
||||
:return: Sample data and its corresponding index inside the buffer.
|
||||
"""
|
||||
if batch_size > 0 and batch_size <= self._size:
|
||||
# Multiple sampling of the same sample
|
||||
# will cause weight update conflict
|
||||
assert self._size > 0, 'cannot sample a buffer with size == 0 !'
|
||||
p = None
|
||||
if batch_size > 0 and (self._replace or batch_size <= self._size):
|
||||
# sampling weight
|
||||
p = (self.weight / self.weight.sum())[:self._size]
|
||||
indice = np.random.choice(
|
||||
self._size, batch_size,
|
||||
p=(self.weight / self.weight.sum())[:self._size],
|
||||
self._size, batch_size, p=p,
|
||||
replace=self._replace)
|
||||
# self._weight_sum is not work for the accuracy issue
|
||||
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
||||
p = p[indice] # weight of each sample
|
||||
elif batch_size == 0:
|
||||
p = np.full(shape=self._size, fill_value=1.0/self._size)
|
||||
indice = np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(0, self._index),
|
||||
])
|
||||
else:
|
||||
# if batch_size larger than len(self),
|
||||
# it will lead to a bug in update weight
|
||||
raise ValueError(
|
||||
"batch_size should be less than len(self), \
|
||||
or set replace=False")
|
||||
f"batch_size should be less than {len(self)}, \
|
||||
or set replace=True")
|
||||
batch = self[indice]
|
||||
if importance_sample:
|
||||
impt_weight = Batch(
|
||||
impt_weight=1 / np.power(
|
||||
self._size * (batch.weight / self._weight_sum),
|
||||
self._beta))
|
||||
batch.cat_(impt_weight)
|
||||
self._check_weight_sum()
|
||||
impt_weight = Batch(
|
||||
impt_weight=(self._size * p) ** (-self._beta))
|
||||
batch.cat_(impt_weight)
|
||||
return batch, indice
|
||||
|
||||
def reset(self) -> None:
|
||||
self._amortization_counter = 0
|
||||
super().reset()
|
||||
|
||||
def update_weight(self, indice: Union[slice, np.ndarray],
|
||||
@ -436,8 +433,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
indice, unique_indice = np.unique(
|
||||
indice, return_index=True)
|
||||
new_weight = new_weight[unique_indice]
|
||||
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
|
||||
- self.weight[indice].sum()
|
||||
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||
|
||||
def __getitem__(self, index: Union[
|
||||
@ -452,10 +447,3 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
weight=self.weight[index],
|
||||
policy=self.get(index, 'policy'),
|
||||
)
|
||||
|
||||
def _check_weight_sum(self) -> None:
|
||||
# keep an accurate _weight_sum
|
||||
self._amortization_counter += 1
|
||||
if self._amortization_counter % self._amortization_freq == 0:
|
||||
self._weight_sum = np.sum(self.weight)
|
||||
self._amortization_counter = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user