fix a bug in re-index replay buffer (fix #82)
This commit is contained in:
parent
c59ad40aef
commit
81e4a16ef2
@ -1,5 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer
|
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from env import MyTestEnv
|
from env import MyTestEnv
|
||||||
@ -29,6 +29,26 @@ def test_replaybuffer(size=10, bufsize=20):
|
|||||||
assert buf2[-1].obs == buf[4].obs
|
assert buf2[-1].obs == buf[4].obs
|
||||||
|
|
||||||
|
|
||||||
|
def test_ignore_obs_next(size=10):
|
||||||
|
# Issue 82
|
||||||
|
buf = ReplayBuffer(size, ignore_obs_net=True)
|
||||||
|
for i in range(size):
|
||||||
|
buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
|
||||||
|
'mask2': np.array([i + 4, 0, 1, 0, 0])},
|
||||||
|
act={'act_id': i,
|
||||||
|
'position_id': i + 3},
|
||||||
|
rew=i,
|
||||||
|
done=i % 3 == 0,
|
||||||
|
info={'if': i})
|
||||||
|
indice = np.arange(len(buf))
|
||||||
|
orig = np.arange(len(buf))
|
||||||
|
data = buf[indice]
|
||||||
|
data2 = buf[indice]
|
||||||
|
assert isinstance(data, Batch)
|
||||||
|
assert isinstance(data2, Batch)
|
||||||
|
assert np.allclose(indice, orig)
|
||||||
|
|
||||||
|
|
||||||
def test_stack(size=5, bufsize=9, stack_num=4):
|
def test_stack(size=5, bufsize=9, stack_num=4):
|
||||||
env = MyTestEnv(size)
|
env = MyTestEnv(size)
|
||||||
buf = ReplayBuffer(bufsize, stack_num)
|
buf = ReplayBuffer(bufsize, stack_num)
|
||||||
@ -74,5 +94,6 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_replaybuffer()
|
test_replaybuffer()
|
||||||
|
test_ignore_obs_next()
|
||||||
test_stack()
|
test_stack()
|
||||||
test_priortized_replaybuffer(233333, 200000)
|
test_priortized_replaybuffer(233333, 200000)
|
||||||
|
@ -42,21 +42,21 @@ def test_fn(size=2560):
|
|||||||
)
|
)
|
||||||
batch = fn(batch, buf, 0)
|
batch = fn(batch, buf, 0)
|
||||||
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
|
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
|
||||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
assert np.allclose(batch.returns, ans)
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
|
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
|
||||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||||
)
|
)
|
||||||
batch = fn(batch, buf, 0)
|
batch = fn(batch, buf, 0)
|
||||||
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
|
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
|
||||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
assert np.allclose(batch.returns, ans)
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
|
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
|
||||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||||
)
|
)
|
||||||
batch = fn(batch, buf, 0)
|
batch = fn(batch, buf, 0)
|
||||||
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
|
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
|
||||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
assert np.allclose(batch.returns, ans)
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
|
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
|
||||||
rew=np.array([
|
rew=np.array([
|
||||||
@ -68,7 +68,7 @@ def test_fn(size=2560):
|
|||||||
454.8344, 376.1143, 291.298, 200.,
|
454.8344, 376.1143, 291.298, 200.,
|
||||||
464.5610, 383.1085, 295.387, 201.,
|
464.5610, 383.1085, 295.387, 201.,
|
||||||
474.2876, 390.1027, 299.476, 202.])
|
474.2876, 390.1027, 299.476, 202.])
|
||||||
assert abs(ret.returns - returns).sum() <= 1e-3
|
assert np.allclose(ret.returns, returns)
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
done=np.random.randint(100, size=size) == 0,
|
done=np.random.randint(100, size=size) == 0,
|
||||||
|
@ -250,6 +250,8 @@ class ReplayBuffer(object):
|
|||||||
else self._size - indice.stop if indice.stop < 0
|
else self._size - indice.stop if indice.stop < 0
|
||||||
else indice.stop,
|
else indice.stop,
|
||||||
1 if indice.step is None else indice.step)
|
1 if indice.step is None else indice.step)
|
||||||
|
else:
|
||||||
|
indice = np.array(indice)
|
||||||
# set last frame done to True
|
# set last frame done to True
|
||||||
last_index = (self._index - 1 + self._size) % self._size
|
last_index = (self._index - 1 + self._size) % self._size
|
||||||
last_done, self.done[last_index] = self.done[last_index], True
|
last_done, self.done[last_index] = self.done[last_index], True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user