2020-04-09 19:53:45 +08:00
|
|
|
import numpy as np
|
2020-03-11 17:28:51 +08:00
|
|
|
from tianshou.data import ReplayBuffer
|
2020-03-26 09:01:20 +08:00
|
|
|
|
2020-03-11 18:02:19 +08:00
|
|
|
if __name__ == '__main__':
|
2020-03-21 10:58:01 +08:00
|
|
|
from env import MyTestEnv
|
2020-03-12 22:20:33 +08:00
|
|
|
else: # pytest
|
2020-03-21 10:58:01 +08:00
|
|
|
from test.base.env import MyTestEnv
|
2020-03-11 17:28:51 +08:00
|
|
|
|
|
|
|
|
2020-03-11 18:02:19 +08:00
|
|
|
def test_replaybuffer(size=10, bufsize=20):
    """Exercise ReplayBuffer add/sample/update on a short scripted rollout.

    Args:
        size: passed to ``MyTestEnv``; sampled observations are asserted to
            be strictly below this value.
        bufsize: capacity of both replay buffers under test.

    The final ordering checks on ``buf2`` hard-code indices 5 and 4, which
    match the default arguments (25 transitions written into a 20-slot
    buffer); they are not expected to hold for other argument values.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf2 = ReplayBuffer(bufsize)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        obs = obs_next
        # Length should track the number of adds until capped at bufsize.
        # (The original used print(...) as the assert message, which
        # evaluates to None and hides the diagnostic in the failure report.)
        assert len(buf) == min(bufsize, i + 1), \
            'len(buf)={} after {} adds (bufsize={})'.format(
                len(buf), i + 1, bufsize)
    # bufsize * 2 exceeds len(buf), so this also checks that sampling more
    # items than are stored still yields only in-range indices.
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    # buf2 has not received any transitions yet.
    assert len(buf) > len(buf2)
    buf2.update(buf)
    assert len(buf) == len(buf2)
    # NOTE(review): presumably update() copies buf in insertion order starting
    # from its oldest surviving entry, which sits at index 5 after 25 writes
    # into 20 slots - valid only for the default size/bufsize.
    assert buf2[0].obs == buf[5].obs
    assert buf2[-1].obs == buf[4].obs
|
2020-03-11 17:28:51 +08:00
|
|
|
|
|
|
|
|
2020-04-09 19:53:45 +08:00
|
|
|
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check the frame-stacked observations returned by ``ReplayBuffer.get``.

    Fifteen transitions, all taken with action 1, are written into a buffer
    of capacity ``bufsize`` created with ``stack_num``. The environment is
    re-created via ``env.reset(1)`` whenever an episode reports ``done``.
    The expected matrix below is hard-coded for the default arguments.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num)
    obs = env.reset(1)
    for _ in range(15):
        obs_next, rew, done, info = env.step(1)
        # obs_next is stored as None; only the 'obs' key is inspected below.
        buf.add(obs, 1, rew, done, None, info)
        obs = env.reset(1) if done else obs_next
    # One row per buffer slot. The repeated leading values suggest frames
    # before an episode start are padded with that episode's first
    # observation - inferred from these expected values, not from this file.
    expected = np.array([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])
    stacked = buf.get(np.arange(len(buf)), 'obs')
    assert abs(stacked - expected).sum() < 1e-6
    print(buf)
|
|
|
|
|
|
|
|
|
2020-03-11 17:28:51 +08:00
|
|
|
if __name__ == '__main__':
    # Run the buffer tests directly as a script, in declaration order.
    for _test in (test_replaybuffer, test_stack):
        _test()
|