bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer (#120)
* bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer * point out that ListReplayBuffer cannot be sampled * remove useless _amortization_counter variable
This commit is contained in:
parent
e767de044b
commit
ff99662fe6
@ -92,20 +92,16 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
|||||||
obs_next, rew, done, info = env.step(a)
|
obs_next, rew, done, info = env.step(a)
|
||||||
buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
|
buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
|
||||||
obs = obs_next
|
obs = obs_next
|
||||||
assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
|
|
||||||
1, rtol=1e-12)
|
|
||||||
data, indice = buf.sample(len(buf) // 2)
|
data, indice = buf.sample(len(buf) // 2)
|
||||||
if len(buf) // 2 == 0:
|
if len(buf) // 2 == 0:
|
||||||
assert len(data) == len(buf)
|
assert len(data) == len(buf)
|
||||||
else:
|
else:
|
||||||
assert len(data) == len(buf) // 2
|
assert len(data) == len(buf) // 2
|
||||||
assert len(buf) == min(bufsize, i + 1)
|
assert len(buf) == min(bufsize, i + 1)
|
||||||
assert np.isclose(buf._weight_sum, (buf.weight).sum())
|
|
||||||
data, indice = buf.sample(len(buf) // 2)
|
data, indice = buf.sample(len(buf) // 2)
|
||||||
buf.update_weight(indice, -data.weight / 2)
|
buf.update_weight(indice, -data.weight / 2)
|
||||||
assert np.isclose(buf.weight[indice], np.power(
|
assert np.allclose(
|
||||||
np.abs(-data.weight / 2), buf._alpha)).all()
|
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
|
||||||
assert np.isclose(buf._weight_sum, (buf.weight).sum())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -151,6 +151,8 @@ class ReplayBuffer:
|
|||||||
|
|
||||||
def update(self, buffer: 'ReplayBuffer') -> None:
|
def update(self, buffer: 'ReplayBuffer') -> None:
|
||||||
"""Move the data from the given buffer to self."""
|
"""Move the data from the given buffer to self."""
|
||||||
|
if len(buffer) == 0:
|
||||||
|
return
|
||||||
i = begin = buffer._index % len(buffer)
|
i = begin = buffer._index % len(buffer)
|
||||||
while True:
|
while True:
|
||||||
self.add(**buffer[i])
|
self.add(**buffer[i])
|
||||||
@ -298,7 +300,9 @@ class ReplayBuffer:
|
|||||||
class ListReplayBuffer(ReplayBuffer):
|
class ListReplayBuffer(ReplayBuffer):
|
||||||
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
|
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
|
||||||
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
|
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
|
||||||
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
|
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore,
|
||||||
|
it does not support advanced indexing, which means you cannot sample a
|
||||||
|
batch of data out of it. It is typically used for storing data.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
@ -309,6 +313,9 @@ class ListReplayBuffer(ReplayBuffer):
|
|||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
||||||
|
|
||||||
|
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||||
|
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
||||||
|
|
||||||
def _add_to_buffer(
|
def _add_to_buffer(
|
||||||
self, name: str,
|
self, name: str,
|
||||||
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
|
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
|
||||||
@ -349,7 +356,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
self._beta = beta
|
self._beta = beta
|
||||||
self._weight_sum = 0.0
|
self._weight_sum = 0.0
|
||||||
self._amortization_freq = 50
|
self._amortization_freq = 50
|
||||||
self._amortization_counter = 0
|
|
||||||
self._replace = replace
|
self._replace = replace
|
||||||
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
|
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
|
||||||
|
|
||||||
@ -369,7 +375,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
self._meta.__dict__['weight'][self._index]
|
self._meta.__dict__['weight'][self._index]
|
||||||
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
||||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||||
self._check_weight_sum()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def replace(self):
|
def replace(self):
|
||||||
@ -379,46 +384,38 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
def replace(self, v: bool):
|
def replace(self, v: bool):
|
||||||
self._replace = v
|
self._replace = v
|
||||||
|
|
||||||
def sample(self, batch_size: int,
|
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||||
importance_sample: bool = True
|
|
||||||
) -> Tuple[Batch, np.ndarray]:
|
|
||||||
"""Get a random sample from buffer with priority probability. \
|
"""Get a random sample from buffer with priority probability. \
|
||||||
Return all the data in the buffer if batch_size is ``0``.
|
Return all the data in the buffer if batch_size is ``0``.
|
||||||
|
|
||||||
:return: Sample data and its corresponding index inside the buffer.
|
:return: Sample data and its corresponding index inside the buffer.
|
||||||
"""
|
"""
|
||||||
if batch_size > 0 and batch_size <= self._size:
|
assert self._size > 0, 'cannot sample a buffer with size == 0 !'
|
||||||
# Multiple sampling of the same sample
|
p = None
|
||||||
# will cause weight update conflict
|
if batch_size > 0 and (self._replace or batch_size <= self._size):
|
||||||
|
# sampling weight
|
||||||
|
p = (self.weight / self.weight.sum())[:self._size]
|
||||||
indice = np.random.choice(
|
indice = np.random.choice(
|
||||||
self._size, batch_size,
|
self._size, batch_size, p=p,
|
||||||
p=(self.weight / self.weight.sum())[:self._size],
|
|
||||||
replace=self._replace)
|
replace=self._replace)
|
||||||
# self._weight_sum is not work for the accuracy issue
|
p = p[indice] # weight of each sample
|
||||||
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
|
||||||
elif batch_size == 0:
|
elif batch_size == 0:
|
||||||
|
p = np.full(shape=self._size, fill_value=1.0/self._size)
|
||||||
indice = np.concatenate([
|
indice = np.concatenate([
|
||||||
np.arange(self._index, self._size),
|
np.arange(self._index, self._size),
|
||||||
np.arange(0, self._index),
|
np.arange(0, self._index),
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
# if batch_size larger than len(self),
|
|
||||||
# it will lead to a bug in update weight
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"batch_size should be less than len(self), \
|
f"batch_size should be less than {len(self)}, \
|
||||||
or set replace=False")
|
or set replace=True")
|
||||||
batch = self[indice]
|
batch = self[indice]
|
||||||
if importance_sample:
|
impt_weight = Batch(
|
||||||
impt_weight = Batch(
|
impt_weight=(self._size * p) ** (-self._beta))
|
||||||
impt_weight=1 / np.power(
|
batch.cat_(impt_weight)
|
||||||
self._size * (batch.weight / self._weight_sum),
|
|
||||||
self._beta))
|
|
||||||
batch.cat_(impt_weight)
|
|
||||||
self._check_weight_sum()
|
|
||||||
return batch, indice
|
return batch, indice
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self._amortization_counter = 0
|
|
||||||
super().reset()
|
super().reset()
|
||||||
|
|
||||||
def update_weight(self, indice: Union[slice, np.ndarray],
|
def update_weight(self, indice: Union[slice, np.ndarray],
|
||||||
@ -436,8 +433,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
indice, unique_indice = np.unique(
|
indice, unique_indice = np.unique(
|
||||||
indice, return_index=True)
|
indice, return_index=True)
|
||||||
new_weight = new_weight[unique_indice]
|
new_weight = new_weight[unique_indice]
|
||||||
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
|
|
||||||
- self.weight[indice].sum()
|
|
||||||
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||||
|
|
||||||
def __getitem__(self, index: Union[
|
def __getitem__(self, index: Union[
|
||||||
@ -452,10 +447,3 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
weight=self.weight[index],
|
weight=self.weight[index],
|
||||||
policy=self.get(index, 'policy'),
|
policy=self.get(index, 'policy'),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_weight_sum(self) -> None:
|
|
||||||
# keep an accurate _weight_sum
|
|
||||||
self._amortization_counter += 1
|
|
||||||
if self._amortization_counter % self._amortization_freq == 0:
|
|
||||||
self._weight_sum = np.sum(self.weight)
|
|
||||||
self._amortization_counter = 0
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user