Replay buffer allows stack_num = 1 (#165)
* stack_num starts at 1 (for no stacking) instead of 0. * Use getter/setter for stack_num. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
38a95c19da
commit
30368c29a6
@ -104,7 +104,7 @@ class ReplayBuffer:
|
||||
|
||||
:param int size: the size of replay buffer.
|
||||
:param int stack_num: the frame-stack sampling argument, should be greater
|
||||
than 1, defaults to 0 (no stacking).
|
||||
than or equal to 1, defaults to 1 (no stacking).
|
||||
:param bool ignore_obs_next: whether to store obs_next, defaults to
|
||||
``False``.
|
||||
:param bool sample_avail: the parameter indicating sampling only available
|
||||
@ -112,13 +112,13 @@ class ReplayBuffer:
|
||||
This feature is not supported in Prioritized Replay Buffer currently.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, stack_num: Optional[int] = 0,
|
||||
def __init__(self, size: int, stack_num: int = 1,
|
||||
ignore_obs_next: bool = False,
|
||||
sample_avail: bool = False, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
self._stack = stack_num
|
||||
assert stack_num != 1, 'stack_num should greater than 1'
|
||||
self._stack = None
|
||||
self.stack_num = stack_num
|
||||
self._avail = sample_avail and stack_num > 1
|
||||
self._avail_index = []
|
||||
self._save_s_ = not ignore_obs_next
|
||||
@ -157,10 +157,13 @@ class ReplayBuffer:
|
||||
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
||||
value[self._index] = inst
|
||||
|
||||
def _get_stack_num(self):
|
||||
@property
|
||||
def stack_num(self):
|
||||
return self._stack
|
||||
|
||||
def _set_stack_num(self, num):
|
||||
@stack_num.setter
|
||||
def stack_num(self, num):
|
||||
assert num > 0, 'stack_num should greater than 0'
|
||||
self._stack = num
|
||||
|
||||
def update(self, buffer: 'ReplayBuffer') -> None:
|
||||
@ -168,14 +171,14 @@ class ReplayBuffer:
|
||||
if len(buffer) == 0:
|
||||
return
|
||||
i = begin = buffer._index % len(buffer)
|
||||
origin = buffer._get_stack_num()
|
||||
buffer._set_stack_num(0)
|
||||
stack_num_orig = buffer.stack_num
|
||||
buffer.stack_num = 1
|
||||
while True:
|
||||
self.add(**buffer[i])
|
||||
i = (i + 1) % len(buffer)
|
||||
if i == begin:
|
||||
break
|
||||
buffer._set_stack_num(origin)
|
||||
buffer.stack_num = stack_num_orig
|
||||
|
||||
def add(self,
|
||||
obs: Union[dict, Batch, np.ndarray],
|
||||
@ -204,15 +207,15 @@ class ReplayBuffer:
|
||||
if self._avail:
|
||||
# update current frame
|
||||
avail = sum(self.done[i] for i in range(
|
||||
self._index - self._stack + 1, self._index)) == 0
|
||||
if self._size < self._stack - 1:
|
||||
self._index - self.stack_num + 1, self._index)) == 0
|
||||
if self._size < self.stack_num - 1:
|
||||
avail = False
|
||||
if avail and self._index not in self._avail_index:
|
||||
self._avail_index.append(self._index)
|
||||
elif not avail and self._index in self._avail_index:
|
||||
self._avail_index.remove(self._index)
|
||||
# remove the later available frame because of broken storage
|
||||
t = (self._index + self._stack - 1) % self._maxsize
|
||||
t = (self._index + self.stack_num - 1) % self._maxsize
|
||||
if t in self._avail_index:
|
||||
self._avail_index.remove(t)
|
||||
|
||||
@ -255,7 +258,7 @@ class ReplayBuffer:
|
||||
given from buffer initialization procedure.
|
||||
"""
|
||||
if stack_num is None:
|
||||
stack_num = self._stack
|
||||
stack_num = self.stack_num
|
||||
if isinstance(indice, slice):
|
||||
indice = np.arange(
|
||||
0 if indice.start is None
|
||||
@ -276,7 +279,7 @@ class ReplayBuffer:
|
||||
key = 'obs'
|
||||
val = self._meta.__dict__[key]
|
||||
try:
|
||||
if stack_num > 0:
|
||||
if stack_num > 1:
|
||||
stack = []
|
||||
for _ in range(stack_num):
|
||||
stack = [val[indice]] + stack
|
||||
@ -300,7 +303,7 @@ class ReplayBuffer:
|
||||
|
||||
def __getitem__(self, index: Union[
|
||||
slice, int, np.integer, np.ndarray]) -> Batch:
|
||||
"""Return a data batch: self[index]. If stack_num is set to be > 0,
|
||||
"""Return a data batch: self[index]. If stack_num is larger than 1,
|
||||
return the stacked obs and obs_next with shape [batch, len, ...].
|
||||
"""
|
||||
return Batch(
|
||||
|
Loading…
x
Reference in New Issue
Block a user