* add shmem vecenv, some add&fix in test_env * generalize test_env IO * pep8 fix * comment update * style change * pep8 fix * style fix * minor fix * fix a bug * test fix * change env * testenv bug fix& shmem support recurse dict * bugfix * pep8 fix * _NP_TO_CT enhance * doc update * docstring update * pep8 fix * style change * style fix * remove assert * minor Co-authored-by: Trinkle23897 <463003665@qq.com>
121 lines
4.0 KiB
Python
121 lines
4.0 KiB
Python
import numpy as np
|
|
|
|
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
|
|
|
|
if __name__ == '__main__':
|
|
from env import MyTestEnv
|
|
else: # pytest
|
|
from test.base.env import MyTestEnv
|
|
|
|
|
|
def test_replaybuffer(size=10, bufsize=20):
    """Exercise basic ReplayBuffer add/sample and mixed-type field storage.

    Part one fills a buffer of capacity ``bufsize`` with transitions from
    ``MyTestEnv(size)``. It checks the buffer length after every add, then
    checks that an oversized sample returns only in-range indices and
    plausible values.

    Part two adds a single transition with heterogeneous fields: an int obs,
    a str ``done`` and a nested dict ``info``. It checks the stored values,
    how the untouched slots are initialised (zeros / ``None``) and the dtype
    *kind* chosen for the nested info entries.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        obs = obs_next
        # The buffer grows until it reaches its capacity, then stays full.
        assert len(buf) == min(bufsize, i + 1)
    # Request more samples than are stored; every index must still be valid.
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    # Check dtype *kinds* with np.issubdtype. ``dtype == np.integer`` would
    # rely on a deprecated implicit conversion of the abstract scalar type
    # to one specific concrete dtype.
    assert b.info.a[0] == 3 and np.issubdtype(b.info.a.dtype, np.integer)
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0
    assert np.issubdtype(b.info.b.c.dtype, np.inexact)
    assert np.all(b.info.b.c[1:] == 0.0)
|
|
|
|
|
|
def test_ignore_obs_next(size=10):
    """Regression test for issue 82 with ``ignore_obs_next`` enabled.

    Stores dict observations and actions in a buffer that does not keep a
    separate ``obs_next``. It then indexes the buffer twice with the same
    index array and checks three things:

    * both reads return a ``Batch``;
    * the reconstructed ``obs_next`` agrees between the two reads;
    * the index array itself is left unmodified.
    """
    # Issue 82
    # The keyword was previously misspelled ``ignore_obs_net``. The
    # constructor did not reject it (presumably swallowed by ``**kwargs``),
    # so the ``ignore_obs_next`` code path was never exercised here.
    buf = ReplayBuffer(size, ignore_obs_next=True)
    for i in range(size):
        buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
                     'mask2': np.array([i + 4, 0, 1, 0, 0])},
                act={'act_id': i,
                     'position_id': i + 3},
                rew=i,
                done=i % 3 == 0,
                info={'if': i})
    indice = np.arange(len(buf))
    orig = np.arange(len(buf))
    data = buf[indice]
    data2 = buf[indice]
    assert isinstance(data, Batch)
    assert isinstance(data2, Batch)
    # Indexing must not mutate the caller's index array; if it did, the
    # second read would also see shifted obs_next values.
    assert np.allclose(indice, orig)
    assert np.allclose(data.obs_next.mask1, data2.obs_next.mask1)
|
|
|
|
|
|
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check frame stacking and ``sample_avail`` sampling of ReplayBuffer.

    Two buffers of capacity ``bufsize`` that stack ``stack_num`` frames
    receive the same 15 transitions from ``MyTestEnv(size)``. The agent
    always takes action 1, and the env is re-reset with argument 1 whenever
    an episode ends.

    The stacked observations of the plain buffer are compared with a
    hard-coded expectation. The ``sample_avail=True`` buffer is checked via
    ``sample(0)`` and ``sample(1)``.
    """
    env = MyTestEnv(size)
    plain_buf = ReplayBuffer(bufsize, stack_num=stack_num)
    avail_buf = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for _ in range(15):
        obs_next, rew, done, info = env.step(1)
        for target in (plain_buf, avail_buf):
            target.add(obs, 1, rew, done, None, info)
        # Start a fresh episode once the current one has terminated.
        obs = env.reset(1) if done else obs_next
    expected = np.expand_dims([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1],
    ], axis=-1)
    stacked_obs = plain_buf.get(np.arange(len(plain_buf)), 'obs')
    assert np.allclose(stacked_obs, expected)
    print(plain_buf)
    _, sampled = avail_buf.sample(0)
    assert sampled == [2]
    _, sampled = avail_buf.sample(1)
    assert sampled.sum() == 2
|
|
|
|
|
|
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Check sampling and weight updates of PrioritizedReplayBuffer.

    Transitions from ``MyTestEnv(size)`` are added with random initial
    weights into a buffer of capacity ``bufsize`` (constructed with
    ``0.5, 0.5``). After every add, a half-size sample is drawn. A request
    for 0 items is expected to return the whole buffer.

    Finally, the weights of a sample are replaced, and the stored weights
    must equal ``abs(new_weight) ** buf._alpha``.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, act in enumerate(actions):
        obs_next, rew, done, info = env.step(act)
        buf.add(obs, act, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        half = len(buf) // 2
        data, _ = buf.sample(half)
        expected_len = len(buf) if half == 0 else half
        assert len(data) == expected_len
        assert len(buf) == min(bufsize, step + 1)
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    expected = np.abs(-data.weight / 2) ** buf._alpha
    assert np.allclose(buf.weight[indice], expected)
|
|
|
|
|
|
def test_update():
    """Check that ``ReplayBuffer.update`` copies another buffer's contents.

    The source buffer (capacity 4, ``stack_num=2``) receives 5 transitions,
    so it wraps around. After ``dst.update(src)``, both buffers must report
    the same length.

    The copy must also be laid out chronologically. ``dst[0]`` matches
    ``src[1]``, the oldest surviving entry. ``dst[-1]`` matches ``src[0]``,
    the slot overwritten by the newest entry.
    """
    src = ReplayBuffer(4, stack_num=2)
    dst = ReplayBuffer(4, stack_num=2)
    for step in range(5):
        src.add(obs=np.array([step]), act=float(step), rew=step * step,
                done=False, info={'incident': 'found'})
    assert len(src) > len(dst)
    dst.update(src)
    assert len(dst) == len(src)
    assert (dst[0].obs == src[1].obs).all()
    assert (dst[-1].obs == src[0].obs).all()
|
|
|
|
|
|
if __name__ == '__main__':
    # Outside pytest, run the whole suite directly, in file order.
    # The prioritized-buffer test is given a much larger env and buffer
    # here than its defaults.
    for test_fn in (test_replaybuffer, test_ignore_obs_next, test_stack):
        test_fn()
    test_priortized_replaybuffer(size=233333, bufsize=200000)
    test_update()
|