From 30368c29a6855417df2713f75746e715c38897bf Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Sat, 25 Jul 2020 13:33:44 +0200 Subject: [PATCH] Replay buffer allows stack_num = 1 (#165) * stack_num starts at 1 (for no stacking) instead of 0. * Use getter/stepper for stack_num. Co-authored-by: Alexis Duburcq --- tianshou/data/buffer.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b7ddcff..4491bee 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -104,7 +104,7 @@ class ReplayBuffer: :param int size: the size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater - than 1, defaults to 0 (no stacking). + than or equal to 1, defaults to 1 (no stacking). :param bool ignore_obs_next: whether to store obs_next, defaults to ``False``. :param bool sample_avail: the parameter indicating sampling only available @@ -112,13 +112,13 @@ class ReplayBuffer: This feature is not supported in Prioritized Replay Buffer currently. """ - def __init__(self, size: int, stack_num: Optional[int] = 0, + def __init__(self, size: int, stack_num: int = 1, ignore_obs_next: bool = False, sample_avail: bool = False, **kwargs) -> None: super().__init__() self._maxsize = size - self._stack = stack_num - assert stack_num != 1, 'stack_num should greater than 1' + self._stack = None + self.stack_num = stack_num self._avail = sample_avail and stack_num > 1 self._avail_index = [] self._save_s_ = not ignore_obs_next @@ -157,10 +157,13 @@ class ReplayBuffer: value.__dict__[key] = _create_value(inst[key], self._maxsize) value[self._index] = inst - def _get_stack_num(self): + @property + def stack_num(self): return self._stack - def _set_stack_num(self, num): + @stack_num.setter + def stack_num(self, num): + assert num > 0, 'stack_num should greater than 0' self._stack = num def update(self, buffer: 'ReplayBuffer') -> None: @@ -168,14 +171,14 @@ class ReplayBuffer: if len(buffer) == 0: return i = begin = buffer._index % len(buffer) - origin = buffer._get_stack_num() - buffer._set_stack_num(0) + stack_num_orig = buffer.stack_num + buffer.stack_num = 1 while True: self.add(**buffer[i]) i = (i + 1) % len(buffer) if i == begin: break - buffer._set_stack_num(origin) + buffer.stack_num = stack_num_orig def add(self, obs: Union[dict, Batch, np.ndarray], @@ -204,15 +207,15 @@ class ReplayBuffer: if self._avail: # update current frame avail = sum(self.done[i] for i in range( - self._index - self._stack + 1, self._index)) == 0 - if self._size < self._stack - 1: + self._index - self.stack_num + 1, self._index)) == 0 + if self._size < self.stack_num - 1: avail = False if avail and self._index not in self._avail_index: self._avail_index.append(self._index) elif not avail and self._index in self._avail_index: self._avail_index.remove(self._index) # remove the later available frame because of broken storage - t = (self._index + self._stack - 1) % self._maxsize + t = (self._index + self.stack_num - 1) % self._maxsize if t in self._avail_index: self._avail_index.remove(t) @@ -255,7 +258,7 @@ class ReplayBuffer: given from buffer initialization procedure. """ if stack_num is None: - stack_num = self._stack + stack_num = self.stack_num if isinstance(indice, slice): indice = np.arange( 0 if indice.start is None @@ -276,7 +279,7 @@ class ReplayBuffer: key = 'obs' val = self._meta.__dict__[key] try: - if stack_num > 0: + if stack_num > 1: stack = [] for _ in range(stack_num): stack = [val[indice]] + stack @@ -300,7 +303,7 @@ class ReplayBuffer: def __getitem__(self, index: Union[ slice, int, np.integer, np.ndarray]) -> Batch: - """Return a data batch: self[index]. If stack_num is set to be > 0, + """Return a data batch: self[index]. If stack_num is larger than 1, return the stacked obs and obs_next with shape [batch, len, ...]. """ return Batch(