Replay buffer allows stack_num = 1 (#165)

* stack_num starts at 1 (for no stacking) instead of 0.

* Use getter/stepper for stack_num.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-25 13:33:44 +02:00 committed by GitHub
parent 38a95c19da
commit 30368c29a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -104,7 +104,7 @@ class ReplayBuffer:
:param int size: the size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than 1, defaults to 0 (no stacking).
than or equal to 1, defaults to 1 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to
``False``.
:param bool sample_avail: the parameter indicating sampling only available
@ -112,13 +112,13 @@ class ReplayBuffer:
This feature is not supported in Prioritized Replay Buffer currently.
"""
def __init__(self, size: int, stack_num: Optional[int] = 0,
def __init__(self, size: int, stack_num: int = 1,
ignore_obs_next: bool = False,
sample_avail: bool = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._stack = stack_num
assert stack_num != 1, 'stack_num should greater than 1'
self._stack = None
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
self._avail_index = []
self._save_s_ = not ignore_obs_next
@ -157,10 +157,13 @@ class ReplayBuffer:
value.__dict__[key] = _create_value(inst[key], self._maxsize)
value[self._index] = inst
def _get_stack_num(self):
@property
def stack_num(self):
return self._stack
def _set_stack_num(self, num):
@stack_num.setter
def stack_num(self, num):
assert num > 0, 'stack_num should greater than 0'
self._stack = num
def update(self, buffer: 'ReplayBuffer') -> None:
@ -168,14 +171,14 @@ class ReplayBuffer:
if len(buffer) == 0:
return
i = begin = buffer._index % len(buffer)
origin = buffer._get_stack_num()
buffer._set_stack_num(0)
stack_num_orig = buffer.stack_num
buffer.stack_num = 1
while True:
self.add(**buffer[i])
i = (i + 1) % len(buffer)
if i == begin:
break
buffer._set_stack_num(origin)
buffer.stack_num = stack_num_orig
def add(self,
obs: Union[dict, Batch, np.ndarray],
@ -204,15 +207,15 @@ class ReplayBuffer:
if self._avail:
# update current frame
avail = sum(self.done[i] for i in range(
self._index - self._stack + 1, self._index)) == 0
if self._size < self._stack - 1:
self._index - self.stack_num + 1, self._index)) == 0
if self._size < self.stack_num - 1:
avail = False
if avail and self._index not in self._avail_index:
self._avail_index.append(self._index)
elif not avail and self._index in self._avail_index:
self._avail_index.remove(self._index)
# remove the later available frame because of broken storage
t = (self._index + self._stack - 1) % self._maxsize
t = (self._index + self.stack_num - 1) % self._maxsize
if t in self._avail_index:
self._avail_index.remove(t)
@ -255,7 +258,7 @@ class ReplayBuffer:
given from buffer initialization procedure.
"""
if stack_num is None:
stack_num = self._stack
stack_num = self.stack_num
if isinstance(indice, slice):
indice = np.arange(
0 if indice.start is None
@ -276,7 +279,7 @@ class ReplayBuffer:
key = 'obs'
val = self._meta.__dict__[key]
try:
if stack_num > 0:
if stack_num > 1:
stack = []
for _ in range(stack_num):
stack = [val[indice]] + stack
@ -300,7 +303,7 @@ class ReplayBuffer:
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
"""Return a data batch: self[index]. If stack_num is set to be > 0,
"""Return a data batch: self[index]. If stack_num is larger than 1,
return the stacked obs and obs_next with shape [batch, len, ...].
"""
return Batch(