fix a bug in re-index replay buffer (fix #82)

This commit is contained in:
Trinkle23897 2020-06-17 16:37:51 +08:00
parent c59ad40aef
commit 81e4a16ef2
3 changed files with 28 additions and 5 deletions

View File

@ -1,5 +1,5 @@
import numpy as np import numpy as np
from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
if __name__ == '__main__': if __name__ == '__main__':
from env import MyTestEnv from env import MyTestEnv
@ -29,6 +29,26 @@ def test_replaybuffer(size=10, bufsize=20):
assert buf2[-1].obs == buf[4].obs assert buf2[-1].obs == buf[4].obs
def test_ignore_obs_next(size=10):
    """Regression test for issue 82: indexing must not alter the index array.

    Fills a ``ReplayBuffer`` created with ``ignore_obs_next=True`` with
    dict-valued observations and actions, indexes it twice with the same
    numpy index array, and checks that each lookup returns a ``Batch`` and
    that the index array still holds its original values afterwards.

    :param int size: capacity of the buffer and the number of transitions
        added to it.
    """
    # Issue 82
    # The keyword is spelled ``ignore_obs_next`` to match the option this
    # test is named after; the earlier ``ignore_obs_net`` spelling did not
    # select that option.
    buf = ReplayBuffer(size, ignore_obs_next=True)
    for i in range(size):
        buf.add(
            obs={
                'mask1': np.array([i, 1, 1, 0, 0]),
                'mask2': np.array([i + 4, 0, 1, 0, 0]),
            },
            act={
                'act_id': i,
                'position_id': i + 3,
            },
            rew=i,
            done=i % 3 == 0,
            info={'if': i},
        )
    indice = np.arange(len(buf))
    # Independent copy of the expected index values, used to detect any
    # in-place modification of ``indice`` by the lookups below.
    orig = np.arange(len(buf))
    data = buf[indice]
    # Index a second time with the same array object.
    data2 = buf[indice]
    assert isinstance(data, Batch)
    assert isinstance(data2, Batch)
    assert np.allclose(indice, orig)
def test_stack(size=5, bufsize=9, stack_num=4): def test_stack(size=5, bufsize=9, stack_num=4):
env = MyTestEnv(size) env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num) buf = ReplayBuffer(bufsize, stack_num)
@ -74,5 +94,6 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
if __name__ == '__main__': if __name__ == '__main__':
test_replaybuffer() test_replaybuffer()
test_ignore_obs_next()
test_stack() test_stack()
test_priortized_replaybuffer(233333, 200000) test_priortized_replaybuffer(233333, 200000)

View File

@ -42,21 +42,21 @@ def test_fn(size=2560):
) )
batch = fn(batch, buf, 0) batch = fn(batch, buf, 0)
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
assert abs(batch.returns - ans).sum() <= 1e-5 assert np.allclose(batch.returns, ans)
batch = Batch( batch = Batch(
done=np.array([0, 1, 0, 1, 0, 1, 0.]), done=np.array([0, 1, 0, 1, 0, 1, 0.]),
rew=np.array([7, 6, 1, 2, 3, 4, 5.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
) )
batch = fn(batch, buf, 0) batch = fn(batch, buf, 0)
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
assert abs(batch.returns - ans).sum() <= 1e-5 assert np.allclose(batch.returns, ans)
batch = Batch( batch = Batch(
done=np.array([0, 1, 0, 1, 0, 0, 1.]), done=np.array([0, 1, 0, 1, 0, 0, 1.]),
rew=np.array([7, 6, 1, 2, 3, 4, 5.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
) )
batch = fn(batch, buf, 0) batch = fn(batch, buf, 0)
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
assert abs(batch.returns - ans).sum() <= 1e-5 assert np.allclose(batch.returns, ans)
batch = Batch( batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
rew=np.array([ rew=np.array([
@ -68,7 +68,7 @@ def test_fn(size=2560):
454.8344, 376.1143, 291.298, 200., 454.8344, 376.1143, 291.298, 200.,
464.5610, 383.1085, 295.387, 201., 464.5610, 383.1085, 295.387, 201.,
474.2876, 390.1027, 299.476, 202.]) 474.2876, 390.1027, 299.476, 202.])
assert abs(ret.returns - returns).sum() <= 1e-3 assert np.allclose(ret.returns, returns)
if __name__ == '__main__': if __name__ == '__main__':
batch = Batch( batch = Batch(
done=np.random.randint(100, size=size) == 0, done=np.random.randint(100, size=size) == 0,

View File

@ -250,6 +250,8 @@ class ReplayBuffer(object):
else self._size - indice.stop if indice.stop < 0 else self._size - indice.stop if indice.stop < 0
else indice.stop, else indice.stop,
1 if indice.step is None else indice.step) 1 if indice.step is None else indice.step)
else:
indice = np.array(indice)
# set last frame done to True # set last frame done to True
last_index = (self._index - 1 + self._size) % self._size last_index = (self._index - 1 + self._size) % self._size
last_done, self.done[last_index] = self.done[last_index], True last_done, self.done[last_index] = self.done[last_index], True