Tianshou/test/base/test_buffer.py

360 lines
13 KiB
Python
Raw Normal View History

import os
import torch
import pickle
import pytest
import tempfile
import h5py
import numpy as np
from timeit import timeit
from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5
2020-03-11 18:02:19 +08:00
# MyTestEnv is imported through a different module path depending on how
# this file is loaded: as a package module (the pytest case) or directly as
# a script, where presumably ``env.py`` sits on ``sys.path`` next to it.
if __name__ != '__main__':  # pytest
    from test.base.env import MyTestEnv
else:
    from env import MyTestEnv


def test_replaybuffer(size=10, bufsize=20):
    """Exercise basic ``ReplayBuffer`` / ``ListReplayBuffer`` behaviour.

    The test checks the repr of an empty buffer, length growth while
    transitions are added, a ``ValueError`` from ``_add_to_buffer``,
    bounds of sampled data, default-filled storage after a single ``add``,
    nested ``info`` handling and out-of-range indexing.

    :param size: argument forwarded to ``MyTestEnv``; sampled observations
        are asserted to be strictly below it.
    :param bufsize: size passed to the ``ReplayBuffer`` under test.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # Updating an empty buffer with itself must leave it empty.
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        # The action is wrapped in a list on purpose: the assertions below
        # expect it to be stored as a Python object.
        buf.add(obs, [a], rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    # Presumably rejected because a length-3 array does not fit the
    # per-step reward slot -- TODO confirm against ``_add_to_buffer``.
    with pytest.raises(ValueError):
        buf._add_to_buffer('rew', np.array([1, 2, 3]))
    # ``np.object`` was only an alias of the builtin and has been removed
    # from NumPy; comparing against ``object`` is equivalent and portable.
    assert buf.act.dtype == object
    assert isinstance(buf.act[0], list)
    # Request more samples than there are stored transitions.
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    # ``dtype == np.integer`` relied on a deprecated implicit conversion of
    # an abstract scalar type to a concrete dtype; ``np.issubdtype`` states
    # the intended "any integer / any floating dtype" check explicitly.
    assert b.info.a[0] == 3
    assert np.issubdtype(b.info.a.dtype, np.integer)
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0
    assert np.issubdtype(b.info.b.c.dtype, np.inexact)
    assert np.all(b.info.b.c[1:] == 0.0)
    with pytest.raises(IndexError):
        b[22]
    b = ListReplayBuffer()
    with pytest.raises(NotImplementedError):
        b.sample(0)


def test_ignore_obs_next(size=10):
    """Check ``obs_next`` retrieval with ``ignore_obs_next=True`` (issue 82).

    Dict observations are added without any ``obs_next``; the buffer is then
    indexed twice, with and without frame stacking, and the results are
    compared with each other and with hard-coded expected values.

    :param size: buffer size and number of transitions added.
    """
    buf = ReplayBuffer(size, ignore_obs_next=True)
    for step in range(size):
        observation = {
            'mask1': np.array([step, 1, 1, 0, 0]),
            'mask2': np.array([step + 4, 0, 1, 0, 0]),
            'mask': step,
        }
        action = {'act_id': step, 'position_id': step + 3}
        buf.add(obs=observation, act=action, rew=step,
                done=step % 3 == 0, info={'if': step})

    indice = np.arange(len(buf))
    untouched = np.arange(len(buf))
    first = buf[indice]
    second = buf[indice]
    assert isinstance(first, Batch)
    assert isinstance(second, Batch)
    # The index array must still hold its original values after indexing.
    assert np.allclose(indice, untouched)
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    expected_next = [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]
    assert np.allclose(first.obs_next.mask, expected_next)

    # Repeat the same checks with frame stacking switched on.
    buf.stack_num = 4
    first = buf[indice]
    second = buf[indice]
    expected_stacked_next = np.array([
        [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
        [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
        [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    assert np.allclose(first.obs_next.mask, expected_stacked_next)
    expected_stacked_info = np.array([
        [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
        [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])
    assert np.allclose(first.info['if'], second.info['if'])
    assert np.allclose(first.info['if'], expected_stacked_info)
    assert first.obs_next


def test_stack(size=5, bufsize=9, stack_num=4):
    """Check frame stacking for three differently configured buffers.

    A default buffer, one built with ``sample_avail=True`` and one built
    with ``save_only_last_obs=True`` receive the same 16 steps; stacked
    observations and the indices returned by sampling are then compared
    with hard-coded expectations.

    :param size: argument forwarded to ``MyTestEnv``.
    :param bufsize: size of every buffer.
    :param stack_num: ``stack_num`` passed to every buffer.
    """
    env = MyTestEnv(size)
    plain = ReplayBuffer(bufsize, stack_num=stack_num)
    avail_only = ReplayBuffer(
        bufsize, stack_num=stack_num, sample_avail=True)
    last_only = ReplayBuffer(
        bufsize, stack_num=stack_num, save_only_last_obs=True)

    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        plain.add(obs, 1, rew, done, None, info)
        avail_only.add(obs, 1, rew, done, None, info)
        # The third buffer is fed list-wrapped observations whose last
        # element is the current observation.
        last_only.add([None, None, obs], 1, rew, done, [None, obs], info)
        # Start a new episode once the current one has finished.
        obs = env.reset(1) if done else obs_next

    indice = np.arange(len(plain))
    expected_first_channel = [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]
    assert np.allclose(
        plain.get(indice, 'obs')[..., 0], expected_first_channel)
    assert np.allclose(
        plain.get(indice, 'obs'), last_only.get(indice, 'obs'))
    assert np.allclose(
        plain.get(indice, 'obs'), last_only.get(indice, 'obs_next'))

    # sample(0) on the ``sample_avail`` buffer is expected to return
    # exactly the indices 2 and 6.
    _, indice = avail_only.sample(0)
    assert indice.tolist() == [2, 6]
    _, indice = avail_only.sample(1)
    assert indice in [2, 6]
    with pytest.raises(IndexError):
        plain[bufsize * 2]


def test_priortized_replaybuffer(size=32, bufsize=15):
    """Check sampling and weight updates of ``PrioritizedReplayBuffer``.

    After every added transition half of the buffer is sampled and the
    batch and buffer lengths are verified; finally the sampled weights are
    updated and read back.

    :param size: argument forwarded to ``MyTestEnv``.
    :param bufsize: size of the prioritized buffer.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for step, action in enumerate(action_list):
        obs_next, rew, done, info = env.step(action)
        priority = np.random.randn() - 0.5
        buf.add(obs, action, rew, done, obs_next, info, priority)
        obs = obs_next
        data, indice = buf.sample(len(buf) // 2)
        half = len(buf) // 2
        # A request for 0 samples is expected to return the whole buffer.
        expected_len = len(buf) if half == 0 else half
        assert len(data) == expected_len
        assert len(buf) == min(bufsize, step + 1)

    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    # Stored weights should equal |new weight| ** alpha (up to tolerance).
    expected_weight = np.abs(-data.weight / 2) ** buf._alpha
    assert np.allclose(buf.weight[indice], expected_weight)


def test_update():
    """Check that ``ReplayBuffer.update`` copies another buffer's content.

    Five transitions are added to a source buffer of size 4; after
    ``update`` the target must have the same length, and its first and
    last entries are compared with entries 1 and 0 of the source.
    """
    source = ReplayBuffer(4, stack_num=2)
    target = ReplayBuffer(4, stack_num=2)
    for step in range(5):
        source.add(
            obs=np.array([step]),
            act=float(step),
            rew=step * step,
            done=step % 2 == 0,
            info={'incident': 'found'},
        )
    assert len(source) > len(target)
    target.update(source)
    assert len(source) == len(target)
    assert (target[0].obs == source[1].obs).all()
    assert (target[-1].obs == source[0].obs).all()


def test_segtree():
    """Check ``SegmentTree`` against a plain NumPy array used as reference.

    Covers single and batched item assignment, range ``reduce`` and
    ``get_prefix_sum_idx`` on a small (8 leaves) and a large (16384 leaves)
    tree.  When the file is executed as a script, a small sampling
    benchmark comparing ``np.random.choice`` with the tree is printed too.
    """
    realop = np.sum
    # small test
    actual_len = 8
    tree = SegmentTree(actual_len)  # 1-15. 8-15 are leaf nodes
    assert len(tree) == actual_len
    assert np.all([tree[i] == 0. for i in range(actual_len)])
    with pytest.raises(IndexError):
        tree[actual_len]
    naive = np.zeros([actual_len])
    for _ in range(1000):
        # randomly choose a place to perform a single update
        index = np.random.randint(actual_len)
        value = np.random.rand()
        naive[index] = value
        tree[index] = value
        # every non-empty range [i, j) must agree with the reference sum
        for i in range(actual_len):
            for j in range(i + 1, actual_len):
                ref = realop(naive[i:j])
                out = tree.reduce(i, j)
                assert np.allclose(ref, out), (ref, out)
    assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
    assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
    # batch setitem
    for _ in range(1000):
        index = np.random.choice(actual_len, size=4)
        value = np.random.rand(4)
        naive[index] = value
        tree[index] = value
        assert np.allclose(realop(naive), tree.reduce())
        for _ in range(10):
            left = np.random.randint(actual_len)
            right = np.random.randint(left + 1, actual_len + 1)
            assert np.allclose(realop(naive[left:right]),
                               tree.reduce(left, right))
    # large test
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.zeros([actual_len])
    for _ in range(1000):
        index = np.random.choice(actual_len, size=64)
        value = np.random.rand(64)
        naive[index] = value
        tree[index] = value
        assert np.allclose(realop(naive), tree.reduce())
        for _ in range(10):
            left = np.random.randint(actual_len)
            right = np.random.randint(left + 1, actual_len + 1)
            assert np.allclose(realop(naive[left:right]),
                               tree.reduce(left, right))
    # test prefix-sum-idx
    actual_len = 8
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # corner case here
    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy, so the builtin is used as the dtype instead.
    naive = np.ones(actual_len, int)
    tree[np.arange(actual_len)] = naive
    for scalar in range(actual_len):
        index = tree.get_prefix_sum_idx(scalar * 1.)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    tree = SegmentTree(10)
    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
    assert np.allclose(tree.get_prefix_sum_idx(
        np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
    # a query equal to the total sum is expected to trip an assertion
    with pytest.raises(AssertionError):
        tree.get_prefix_sum_idx(.2)
    # test large prefix-sum-idx
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # profile (only when executed as a script, never under pytest)
    if __name__ == '__main__':
        size = 100000
        bsz = 64
        naive = np.random.rand(size)
        tree = SegmentTree(size)
        tree[np.arange(size)] = naive

        def sample_npbuf():
            return np.random.choice(size, bsz, p=naive / naive.sum())

        def sample_tree():
            scalar = np.random.rand(bsz) * tree.reduce()
            return tree.get_prefix_sum_idx(scalar)

        print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
        print('tree', timeit(sample_tree, setup=sample_tree, number=1000))


def test_pickle():
    """Check that all three buffer kinds survive a pickle round trip.

    Length and stored actions are compared for every buffer; additionally
    ``stack_num`` is compared for the plain buffer and the weights for the
    prioritized one.
    """
    size = 100
    vbuf = ReplayBuffer(size, stack_num=2)
    lbuf = ListReplayBuffer()
    pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    # The reward is a torch tensor placed on the selected device.
    rew = torch.tensor([1.]).to(device)
    for step in range(4):
        vbuf.add(obs=Batch(index=np.array([step])), act=0, rew=rew, done=0)
    for step in range(3):
        lbuf.add(obs=Batch(index=np.array([step])), act=1, rew=rew, done=0)
    for step in range(5):
        pbuf.add(obs=Batch(index=np.array([step])),
                 act=2, rew=rew, done=0, weight=np.random.rand())

    # save & load
    originals = (vbuf, lbuf, pbuf)
    restored = [pickle.loads(pickle.dumps(buf)) for buf in originals]
    for loaded, reference in zip(restored, originals):
        assert len(loaded) == len(reference)
        assert np.allclose(loaded.act, reference.act)

    # make sure the meta var is identical
    _vbuf, _, _pbuf = restored
    assert _vbuf.stack_num == vbuf.stack_num
    assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
                       pbuf.weight[np.arange(len(pbuf))])


def test_hdf5():
    """Check HDF5 save/load of the buffers and ``to_hdf5`` error handling.

    Each buffer kind is saved to a temporary ``.hdf5`` file, loaded back
    through the ``load_hdf5`` constructor of its class and compared with
    the original.  The temporary files are removed in a ``finally`` clause
    so they do not accumulate when saving, loading or an assertion fails.
    Afterwards ``to_hdf5`` is called with two values that are expected to
    be rejected.
    """
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rew = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': rew,
            'done': 0,
            'info': {"number": {"n": i}, 'extra': None},
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)
    paths = {}
    try:
        # save
        for k, buf in buffers.items():
            f, path = tempfile.mkstemp(suffix='.hdf5')
            # Record the path first so the file is cleaned up even if
            # closing the descriptor or saving the buffer fails.
            paths[k] = path
            os.close(f)
            buf.save_hdf5(path)
        # load replay buffer
        _buffers = {
            k: buffer_types[k].load_hdf5(path) for k, path in paths.items()
        }
        # compare
        for k in buffers:
            assert len(_buffers[k]) == len(buffers[k])
            assert np.allclose(_buffers[k].act, buffers[k].act)
            assert _buffers[k].stack_num == buffers[k].stack_num
            assert _buffers[k]._maxsize == buffers[k]._maxsize
            assert _buffers[k]._index == buffers[k]._index
            assert np.all(_buffers[k]._indices == buffers[k]._indices)
        for k in ["array", "prioritized"]:
            assert isinstance(buffers[k].get(0, "info"), Batch)
            assert isinstance(_buffers[k].get(0, "info"), Batch)
        for k in ["array"]:
            assert np.all(
                buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
            assert np.all(
                buffers[k][:].info.extra == _buffers[k][:].info.extra)
    finally:
        for path in paths.values():
            os.remove(path)
    # raise exception when value cannot be pickled
    # NOTE(review): ``grp`` is the ``h5py.Group`` class itself, not an open
    # group -- presumably the error is raised before it is used; confirm.
    data = {"not_supported": lambda x: x*x}
    grp = h5py.Group
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # ndarray with data type not supported by HDF5 that cannot be pickled
    data = {"not_supported": np.array(lambda x: x*x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)


if __name__ == '__main__':
    # Script entry point: run the tests without pytest, in a fixed order.
    for run_test in (
        test_hdf5,
        test_replaybuffer,
        test_ignore_obs_next,
        test_stack,
        test_pickle,
        test_segtree,
        test_priortized_replaybuffer,
    ):
        run_test()
    # Second prioritized-buffer run with much larger env and buffer sizes.
    test_priortized_replaybuffer(233333, 200000)
    test_update()