From d09b69e594c2bd1d463322d95a448b9df76ef17c Mon Sep 17 00:00:00 2001 From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com> Date: Mon, 20 Jul 2020 22:12:57 +0800 Subject: [PATCH] buffer update bug fix (#154) * buffer update bug fix * some fix in buffer update * polish Co-authored-by: n+e <463003665@qq.com> --- test/base/test_buffer.py | 23 ++++++++++++++++------- tianshou/data/buffer.py | 11 ++++++++++- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 28ccd88..6178a32 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,5 +1,6 @@ import numpy as np -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer + +from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -10,7 +11,6 @@ else: # pytest def test_replaybuffer(size=10, bufsize=20): env = MyTestEnv(size) buf = ReplayBuffer(bufsize) - buf2 = ReplayBuffer(bufsize) obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): @@ -22,11 +22,6 @@ def test_replaybuffer(size=10, bufsize=20): assert (indice < len(buf)).all() assert (data.obs < size).all() assert (0 <= data.done).all() and (data.done <= 1).all() - assert len(buf) > len(buf2) - buf2.update(buf) - assert len(buf) == len(buf2) - assert buf2[0].obs == buf[5].obs - assert buf2[-1].obs == buf[4].obs b = ReplayBuffer(size=10) b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) assert b.obs[0] == 1 @@ -104,8 +99,22 @@ def test_priortized_replaybuffer(size=32, bufsize=15): buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) +def test_update(): + buf1 = ReplayBuffer(4, stack_num=2) + buf2 = ReplayBuffer(4, stack_num=2) + for i in range(5): + buf1.add(obs=np.array([i]), act=float(i), rew=i * i, + done=False, info={'incident': 'found'}) + assert len(buf1) > len(buf2) + buf2.update(buf1) + assert len(buf1) == len(buf2) + assert (buf2[0].obs == buf1[1].obs).all() + assert (buf2[-1].obs == buf1[0].obs).all() + + if __name__ == '__main__': test_replaybuffer() test_ignore_obs_next() test_stack() test_priortized_replaybuffer(233333, 200000) + test_update() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index a2e658d..b7ddcff 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -157,16 +157,25 @@ class ReplayBuffer: value.__dict__[key] = _create_value(inst[key], self._maxsize) value[self._index] = inst + def _get_stack_num(self): + return self._stack + + def _set_stack_num(self, num): + self._stack = num + def update(self, buffer: 'ReplayBuffer') -> None: """Move the data from the given buffer to self.""" if len(buffer) == 0: return i = begin = buffer._index % len(buffer) + origin = buffer._get_stack_num() + buffer._set_stack_num(0) while True: self.add(**buffer[i]) i = (i + 1) % len(buffer) if i == begin: break + buffer._set_stack_num(origin) def add(self, obs: Union[dict, Batch, np.ndarray], @@ -408,7 +417,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): replace=self._replace) p = p[indice] # weight of each sample elif batch_size == 0: - p = np.full(shape=self._size, fill_value=1.0/self._size) + p = np.full(shape=self._size, fill_value=1.0 / self._size) indice = np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index),