From 81e4a16ef2f705044d06f32e8e902c7d98a24083 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 17 Jun 2020 16:37:51 +0800 Subject: [PATCH] fix a bug in re-index replay buffer (fix #82) --- test/base/test_buffer.py | 23 ++++++++++++++++++++++- test/discrete/test_pg.py | 8 ++++---- tianshou/data/buffer.py | 2 ++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 6474d44..e45151e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,5 +1,5 @@ import numpy as np -from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -29,6 +29,26 @@ def test_replaybuffer(size=10, bufsize=20): assert buf2[-1].obs == buf[4].obs +def test_ignore_obs_next(size=10): + # Issue 82 + buf = ReplayBuffer(size, ignore_obs_net=True) + for i in range(size): + buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]), + 'mask2': np.array([i + 4, 0, 1, 0, 0])}, + act={'act_id': i, + 'position_id': i + 3}, + rew=i, + done=i % 3 == 0, + info={'if': i}) + indice = np.arange(len(buf)) + orig = np.arange(len(buf)) + data = buf[indice] + data2 = buf[indice] + assert isinstance(data, Batch) + assert isinstance(data2, Batch) + assert np.allclose(indice, orig) + + def test_stack(size=5, bufsize=9, stack_num=4): env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num) @@ -74,5 +94,6 @@ def test_priortized_replaybuffer(size=32, bufsize=15): if __name__ == '__main__': test_replaybuffer() + test_ignore_obs_next() test_stack() test_priortized_replaybuffer(233333, 200000) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 7607151..d2817a7 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -42,21 +42,21 @@ def test_fn(size=2560): ) batch = fn(batch, buf, 0) ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) - assert abs(batch.returns - ans).sum() <= 1e-5 + assert np.allclose(batch.returns, ans) batch = Batch( done=np.array([0, 1, 0, 1, 0, 1, 0.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]), ) batch = fn(batch, buf, 0) ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) - assert abs(batch.returns - ans).sum() <= 1e-5 + assert np.allclose(batch.returns, ans) batch = Batch( done=np.array([0, 1, 0, 1, 0, 0, 1.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]), ) batch = fn(batch, buf, 0) ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) - assert abs(batch.returns - ans).sum() <= 1e-5 + assert np.allclose(batch.returns, ans) batch = Batch( done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), rew=np.array([ @@ -68,7 +68,7 @@ def test_fn(size=2560): 454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085, 295.387, 201., 474.2876, 390.1027, 299.476, 202.]) - assert abs(ret.returns - returns).sum() <= 1e-3 + assert np.allclose(ret.returns, returns) if __name__ == '__main__': batch = Batch( done=np.random.randint(100, size=size) == 0, diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index a87f53b..592be5f 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -250,6 +250,8 @@ class ReplayBuffer(object): else self._size - indice.stop if indice.stop < 0 else indice.stop, 1 if indice.step is None else indice.step) + else: + indice = np.array(indice) # set last frame done to True last_index = (self._index - 1 + self._size) % self._size last_done, self.done[last_index] = self.done[last_index], True