1. add policy.eval() in all test scripts' "watch performance" 2. remove dict return support for collector preprocess_fn 3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, deft)` 4. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184) 5. fix tensorboard logging: the horizontal axis now stands for env step instead of gradient step; add test results into tensorboard 6. add test_returns (both GAE and nstep) 7. change the type-checking order in batch.py and converter.py so that the most common case is checked first 8. fix shape inconsistency for torch.Tensor in replay buffer 9. remove `**kwargs` in ReplayBuffer 10. remove default value in batch.split() and add merge_last argument (#185) 11. improve nstep efficiency 12. add max_batchsize in onpolicy algorithms 13. potential bugfix for subproc.wait 14. fix RecurrentActorProb 15. improve the code coverage (from 90% to 95%) and remove dead code 16. fix some incorrect type annotations. The above improvements also increase the training FPS: on my computer, the previous version reaches only ~1800 FPS, and after these changes it can reach ~2050 (faster than v0.2.4.post1).
287 lines
10 KiB
Python
287 lines
10 KiB
Python
import torch
|
|
import pickle
|
|
import pytest
|
|
import numpy as np
|
|
from timeit import timeit
|
|
|
|
from tianshou.data import Batch, SegmentTree, \
|
|
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
|
|
|
|
if __name__ == '__main__':
|
|
from env import MyTestEnv
|
|
else: # pytest
|
|
from test.base.env import MyTestEnv
|
|
|
|
|
|
def test_replaybuffer(size=10, bufsize=20):
    """Exercise basic ``ReplayBuffer`` operations and ``ListReplayBuffer``.

    :param size: argument forwarded to ``MyTestEnv``; sampled observations
        are asserted to be strictly smaller than it.
    :param bufsize: capacity passed to the ``ReplayBuffer`` under test.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # Updating an empty buffer with itself must leave it empty.
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        # The action is wrapped in a list on purpose: the dtype and
        # isinstance assertions below rely on it being stored as an object.
        buf.add(obs, [a], rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    with pytest.raises(ValueError):
        buf._add_to_buffer('rew', np.array([1, 2, 3]))
    # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the
    # equivalent dtype specifier on every NumPy version.
    assert buf.act.dtype == object
    assert isinstance(buf.act[0], list)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()

    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    # Comparing a concrete dtype with an abstract scalar type such as
    # ``np.integer`` relies on a deprecated implicit conversion;
    # ``np.issubdtype`` states the intended "is some integer/float type"
    # check explicitly.
    assert b.info.a[0] == 3 and np.issubdtype(b.info.a.dtype, np.integer)
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and \
        np.issubdtype(b.info.b.c.dtype, np.inexact)
    assert np.all(b.info.b.c[1:] == 0.0)
    with pytest.raises(IndexError):
        b[22]

    b = ListReplayBuffer()
    with pytest.raises(NotImplementedError):
        b.sample(0)
|
|
|
|
|
|
def test_ignore_obs_next(size=10):
    """Check ``obs_next`` retrieval from a buffer built with
    ``ignore_obs_next=True``, with and without frame stacking.

    :param size: buffer capacity and number of transitions added.
    """
    # Issue 82
    buf = ReplayBuffer(size, ignore_obs_next=True)
    for step in range(size):
        observation = {
            'mask1': np.array([step, 1, 1, 0, 0]),
            'mask2': np.array([step + 4, 0, 1, 0, 0]),
            'mask': step,
        }
        action = {'act_id': step, 'position_id': step + 3}
        buf.add(obs=observation, act=action, rew=step,
                done=step % 3 == 0, info={'if': step})

    indice = np.arange(len(buf))
    untouched = np.arange(len(buf))
    first = buf[indice]
    second = buf[indice]
    assert isinstance(first, Batch)
    assert isinstance(second, Batch)
    # Indexing must not modify the index array that was passed in.
    assert np.allclose(indice, untouched)
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    assert np.allclose(first.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])

    buf.stack_num = 4
    first = buf[indice]
    second = buf[indice]
    expected_mask = np.array([
        [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
        [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
        [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    assert np.allclose(first.obs_next.mask, expected_mask)
    expected_info = np.array([
        [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
        [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])
    assert np.allclose(first.info['if'], second.info['if'])
    assert np.allclose(first.info['if'], expected_info)
    assert first.obs_next
|
|
|
|
|
|
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check stacked observations and ``sample_avail`` sampling indices.

    :param size: argument forwarded to ``MyTestEnv``.
    :param bufsize: capacity of both buffers.
    :param stack_num: ``stack_num`` passed to both buffers.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        # Both buffers receive the same transitions; obs_next is not stored.
        for target in (buf, buf2):
            target.add(obs, 1, rew, done, None, info)
        obs = env.reset(1) if done else obs_next

    indice = np.arange(len(buf))
    expected_stack = [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]
    assert np.allclose(buf.get(indice, 'obs')[..., 0], expected_stack)
    # A batch size of 0 is asserted to yield exactly the indices [2, 6].
    _, indice = buf2.sample(0)
    assert indice.tolist() == [2, 6]
    _, indice = buf2.sample(1)
    assert indice in [2, 6]
    with pytest.raises(IndexError):
        buf[bufsize * 2]
|
|
|
|
|
|
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Check sampling sizes and weight updates of PrioritizedReplayBuffer.

    :param size: argument forwarded to ``MyTestEnv``.
    :param bufsize: capacity of the prioritized buffer.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, act in enumerate(actions):
        obs_next, rew, done, info = env.step(act)
        buf.add(obs, act, rew, done, obs_next, info,
                np.random.randn() - 0.5)
        obs = obs_next
        batch, _ = buf.sample(len(buf) // 2)
        half = len(buf) // 2
        # A requested batch size of 0 is expected to return the whole buffer.
        expected_len = len(buf) if half == 0 else half
        assert len(batch) == expected_len
        assert len(buf) == min(bufsize, step + 1)

    batch, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -batch.weight / 2)
    # Stored weights are expected to be |w| ** alpha of the values passed in.
    expected_weight = np.abs(-batch.weight / 2) ** buf._alpha
    assert np.allclose(buf.weight[indice], expected_weight)
|
|
|
|
|
|
def test_update():
    """Check that ``ReplayBuffer.update`` copies another buffer's content."""
    source = ReplayBuffer(4, stack_num=2)
    target = ReplayBuffer(4, stack_num=2)
    # Five transitions are added to a buffer created with size 4.
    for k in range(5):
        source.add(
            obs=np.array([k]),
            act=float(k),
            rew=k * k,
            done=k % 2 == 0,
            info={'incident': 'found'},
        )
    assert len(source) > len(target)
    target.update(source)
    assert len(source) == len(target)
    assert (target[0].obs == source[1].obs).all()
    assert (target[-1].obs == source[0].obs).all()
|
|
|
|
|
|
def test_segtree():
    """Compare ``SegmentTree`` against naive NumPy reductions.

    Covers single and batched item assignment with ``reduce`` for the
    'sum', 'max' and 'min' operations, then ``get_prefix_sum_idx`` on small
    and large trees. When the module is run as a script, a small timing
    comparison against ``np.random.choice`` is printed as well.
    """
    for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
        realop = getattr(np, op)
        # small test
        actual_len = 8
        tree = SegmentTree(actual_len, op)  # 1-15. 8-15 are leaf nodes
        assert len(tree) == actual_len
        assert np.all([tree[i] == init for i in range(actual_len)])
        with pytest.raises(IndexError):
            tree[actual_len]
        naive = np.full([actual_len], init)
        for _ in range(1000):
            # randomly choose a place to perform a single update
            index = np.random.randint(actual_len)
            value = np.random.rand()
            naive[index] = value
            tree[index] = value
            for i in range(actual_len):
                for j in range(i + 1, actual_len):
                    ref = realop(naive[i:j])
                    out = tree.reduce(i, j)
                    assert np.allclose(ref, out)
            assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
            assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
        # batch setitem
        for _ in range(1000):
            index = np.random.choice(actual_len, size=4)
            value = np.random.rand(4)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(realop(naive[left:right]),
                                   tree.reduce(left, right))
        # large test
        actual_len = 16384
        tree = SegmentTree(actual_len, op)
        naive = np.full([actual_len], init)
        for _ in range(1000):
            index = np.random.choice(actual_len, size=64)
            value = np.random.rand(64)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(realop(naive[left:right]),
                                   tree.reduce(left, right))

    # test prefix-sum-idx
    actual_len = 8
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # corner case here
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` selects the
    # same default integer dtype.
    naive = np.ones(actual_len, int)
    tree[np.arange(actual_len)] = naive
    for scalar in range(actual_len):
        index = tree.get_prefix_sum_idx(scalar * 1.)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    tree = SegmentTree(10)
    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
    assert np.allclose(tree.get_prefix_sum_idx(
        np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
    # Querying a value equal to the total sum is expected to be rejected.
    with pytest.raises(AssertionError):
        tree.get_prefix_sum_idx(.2)
    # test large prefix-sum-idx
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()

    # profile
    if __name__ == '__main__':
        size = 100000
        bsz = 64
        naive = np.random.rand(size)
        tree = SegmentTree(size)
        tree[np.arange(size)] = naive

        def sample_npbuf():
            # Baseline: weighted sampling done directly by NumPy.
            return np.random.choice(size, bsz, p=naive / naive.sum())

        def sample_tree():
            # Same kind of sampling through prefix-sum lookup in the tree.
            scalar = np.random.rand(bsz) * tree.reduce()
            return tree.get_prefix_sum_idx(scalar)

        print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
        print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
|
|
|
|
|
|
def test_pickle():
    """Round-trip the three buffer types through ``pickle``."""
    size = 100
    vbuf = ReplayBuffer(size, stack_num=2)
    lbuf = ListReplayBuffer()
    pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
    # The reward is a torch tensor, placed on the GPU when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rew = torch.tensor([1.]).to(device)
    for k in range(4):
        vbuf.add(obs=Batch(index=np.array([k])), act=0, rew=rew, done=0)
    for k in range(3):
        lbuf.add(obs=Batch(index=np.array([k])), act=1, rew=rew, done=0)
    for k in range(5):
        pbuf.add(obs=Batch(index=np.array([k])),
                 act=2, rew=rew, done=0, weight=np.random.rand())

    # save & load
    vbuf_copy = pickle.loads(pickle.dumps(vbuf))
    lbuf_copy = pickle.loads(pickle.dumps(lbuf))
    pbuf_copy = pickle.loads(pickle.dumps(pbuf))
    pairs = ((vbuf_copy, vbuf), (lbuf_copy, lbuf), (pbuf_copy, pbuf))
    for restored, original in pairs:
        assert len(restored) == len(original) and \
            np.allclose(restored.act, original.act)

    # make sure the meta var is identical
    assert vbuf_copy.stack_num == vbuf.stack_num
    restored_weight = pbuf_copy.weight[np.arange(len(pbuf_copy))]
    original_weight = pbuf.weight[np.arange(len(pbuf))]
    assert np.allclose(restored_weight, original_weight)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run every test in a fixed order when the file is executed as a script.
    for run_test in (
        test_replaybuffer,
        test_ignore_obs_next,
        test_stack,
        test_pickle,
        test_segtree,
        test_priortized_replaybuffer,
    ):
        run_test()
    # Second prioritized-buffer run with much larger env and buffer sizes.
    test_priortized_replaybuffer(233333, 200000)
    test_update()
|