Tianshou/test/test_buffer.py

32 lines
983 B
Python
Raw Normal View History

2020-03-11 17:28:51 +08:00
from tianshou.data import ReplayBuffer
2020-03-11 18:02:19 +08:00
if __name__ == '__main__':
from test_env import MyTestEnv
2020-03-12 22:20:33 +08:00
else: # pytest
2020-03-11 18:02:19 +08:00
from test.test_env import MyTestEnv
2020-03-11 17:28:51 +08:00
2020-03-11 18:02:19 +08:00
def test_replaybuffer(size=10, bufsize=20):
    """Exercise ``ReplayBuffer.add``, ``sample`` and ``update``.

    Steps a ``MyTestEnv`` through a fixed action sequence, storing every
    transition in a buffer of capacity ``bufsize``, then checks that:

    * the buffer length grows by one per step until it saturates at
      ``bufsize``;
    * sampled indices are in range and sampled ``obs`` / ``done`` values
      stay within the expected bounds;
    * ``update`` copies the filled buffer into an empty one of the same
      capacity, with the oldest stored transition landing first.

    Args:
        size: forwarded to ``MyTestEnv``; sampled observations are
            asserted to be strictly below it.
        bufsize: capacity of both replay buffers.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf2 = ReplayBuffer(bufsize)
    obs = env.reset()
    # 25 steps in total: enough to overflow the default 20-slot buffer.
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1), (
            f'len(buf)={len(buf)}, step={i}')
    # Ask for more samples than can be stored (len(buf) <= bufsize), so
    # sampling must reuse entries while keeping indices in range.
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()

    assert len(buf) > len(buf2)
    buf2.update(buf)
    assert len(buf) == len(buf2)
    # update() is expected to replay ``buf`` oldest-first. Once ``buf`` has
    # wrapped, its oldest slot is the one the next write would overwrite
    # (25 % 20 == 5 for the defaults); before wrapping it is slot 0.
    # Deriving the indices keeps the check valid for any ``bufsize``.
    n_steps = len(action_list)
    oldest = n_steps % bufsize if n_steps >= bufsize else 0
    newest = (oldest - 1) % len(buf)
    assert buf2[0].obs == buf[oldest].obs
    assert buf2[len(buf2) - 1].obs == buf[newest].obs
2020-03-11 17:28:51 +08:00
if __name__ == '__main__':
    # Direct script execution (outside pytest) runs the test with its
    # default environment and buffer sizes.
    test_replaybuffer()