Tianshou/test/test_buffer.py
2020-03-11 17:28:51 +08:00

19 lines
554 B
Python

from tianshou.data import ReplayBuffer
from test.test_env import MyTestEnv
def test_replaybuffer(bufsize=20):
    """Roll out ``MyTestEnv`` and check ``ReplayBuffer`` length and sampling.

    Steps a ``MyTestEnv(10)`` through a fixed 24-action sequence and stores
    every transition in a ``ReplayBuffer(bufsize)``. After each insertion it
    checks that ``len(buf)`` equals ``min(bufsize, steps_so_far)``. The
    sequence is longer than the default capacity of 20, so the check also
    covers inserts into an already-full buffer. Finally it calls
    ``sample_indice`` and ``sample`` to make sure sampling does not raise.

    :param bufsize: capacity passed to ``ReplayBuffer``.
    """
    env = MyTestEnv(10)
    buf = ReplayBuffer(bufsize)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 9
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        # Advance along the trajectory. Without this, every stored
        # transition would start from the initial reset observation.
        obs = obs_next
        expected = min(bufsize, i + 1)
        assert len(buf) == expected, (
            'step {}: len(buf) == {}, expected {}'.format(
                i, len(buf), expected))
    # NOTE(review): these calls only verify that sampling does not raise;
    # the returned values are not inspected here.
    buf.sample_indice(4)
    buf.sample(4)
if __name__ == "__main__":
    # Allow running this test module directly as a script, without pytest.
    test_replaybuffer()