fix a bug in re-index replay buffer (fix #82)
This commit is contained in:
parent
c59ad40aef
commit
81e4a16ef2
@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer
|
||||
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
|
||||
|
||||
if __name__ == '__main__':
|
||||
from env import MyTestEnv
|
||||
@ -29,6 +29,26 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
assert buf2[-1].obs == buf[4].obs
|
||||
|
||||
|
||||
def test_ignore_obs_next(size=10):
|
||||
# Issue 82
|
||||
buf = ReplayBuffer(size, ignore_obs_net=True)
|
||||
for i in range(size):
|
||||
buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
|
||||
'mask2': np.array([i + 4, 0, 1, 0, 0])},
|
||||
act={'act_id': i,
|
||||
'position_id': i + 3},
|
||||
rew=i,
|
||||
done=i % 3 == 0,
|
||||
info={'if': i})
|
||||
indice = np.arange(len(buf))
|
||||
orig = np.arange(len(buf))
|
||||
data = buf[indice]
|
||||
data2 = buf[indice]
|
||||
assert isinstance(data, Batch)
|
||||
assert isinstance(data2, Batch)
|
||||
assert np.allclose(indice, orig)
|
||||
|
||||
|
||||
def test_stack(size=5, bufsize=9, stack_num=4):
|
||||
env = MyTestEnv(size)
|
||||
buf = ReplayBuffer(bufsize, stack_num)
|
||||
@ -74,5 +94,6 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_replaybuffer()
|
||||
test_ignore_obs_next()
|
||||
test_stack()
|
||||
test_priortized_replaybuffer(233333, 200000)
|
||||
|
@ -42,21 +42,21 @@ def test_fn(size=2560):
|
||||
)
|
||||
batch = fn(batch, buf, 0)
|
||||
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
|
||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||
assert np.allclose(batch.returns, ans)
|
||||
batch = Batch(
|
||||
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
|
||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||
)
|
||||
batch = fn(batch, buf, 0)
|
||||
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
|
||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||
assert np.allclose(batch.returns, ans)
|
||||
batch = Batch(
|
||||
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
|
||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||
)
|
||||
batch = fn(batch, buf, 0)
|
||||
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
|
||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||
assert np.allclose(batch.returns, ans)
|
||||
batch = Batch(
|
||||
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
|
||||
rew=np.array([
|
||||
@ -68,7 +68,7 @@ def test_fn(size=2560):
|
||||
454.8344, 376.1143, 291.298, 200.,
|
||||
464.5610, 383.1085, 295.387, 201.,
|
||||
474.2876, 390.1027, 299.476, 202.])
|
||||
assert abs(ret.returns - returns).sum() <= 1e-3
|
||||
assert np.allclose(ret.returns, returns)
|
||||
if __name__ == '__main__':
|
||||
batch = Batch(
|
||||
done=np.random.randint(100, size=size) == 0,
|
||||
|
@ -250,6 +250,8 @@ class ReplayBuffer(object):
|
||||
else self._size - indice.stop if indice.stop < 0
|
||||
else indice.stop,
|
||||
1 if indice.step is None else indice.step)
|
||||
else:
|
||||
indice = np.array(indice)
|
||||
# set last frame done to True
|
||||
last_index = (self._index - 1 + self._size) % self._size
|
||||
last_done, self.done[last_index] = self.done[last_index], True
|
||||
|
Loading…
x
Reference in New Issue
Block a user