Tianshou/test/base/test_buffer.py
2023-08-09 10:27:18 -07:00

1354 lines
42 KiB
Python

import os
import pickle
import tempfile
from timeit import timeit
import h5py
import numpy as np
import pytest
import torch
from tianshou.data import (
Batch,
CachedReplayBuffer,
HERReplayBuffer,
HERVectorReplayBuffer,
PrioritizedReplayBuffer,
PrioritizedVectorReplayBuffer,
ReplayBuffer,
SegmentTree,
VectorReplayBuffer,
)
from tianshou.data.utils.converter import to_hdf5
if __name__ == '__main__':
from env import MyGoalEnv, MyTestEnv
else: # pytest
from test.base.env import MyGoalEnv, MyTestEnv
def test_replaybuffer(size=10, bufsize=20):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize)
buf.update(buf)
assert str(buf) == buf.__class__.__name__ + '()'
obs, _ = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, act in enumerate(action_list):
obs_next, rew, terminated, truncated, info = env.step(act)
buf.add(
Batch(
obs=obs,
act=[act],
rew=rew,
terminated=terminated,
truncated=truncated,
obs_next=obs_next,
info=info
)
)
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
assert buf.act.dtype == int
assert buf.act.shape == (bufsize, 1)
data, indices = buf.sample(bufsize * 2)
assert (indices < len(buf)).all()
assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all()
assert (0 <= data.terminated).all() and (data.terminated <= 1).all()
assert (0 <= data.truncated).all() and (data.truncated <= 1).all()
b = ReplayBuffer(size=10)
# neg bsz should return empty index
assert b.sample_indices(-1).tolist() == []
ptr, ep_rew, ep_len, ep_idx = b.add(
Batch(
obs=1,
act=1,
rew=1,
terminated=1,
truncated=0,
obs_next='str',
info={
'a': 3,
'b': {
'c': 5.0
}
}
)
)
assert b.obs[0] == 1
assert b.done[0]
assert b.terminated[0]
assert not b.truncated[0]
assert b.obs_next[0] == 'str'
assert np.all(b.obs[1:] == 0)
assert np.all(b.obs_next[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == int
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
assert np.all(b.info.b.c[1:] == 0.0)
assert ptr.shape == (1, ) and ptr[0] == 0
assert ep_rew.shape == (1, ) and ep_rew[0] == 1
assert ep_len.shape == (1, ) and ep_len[0] == 1
assert ep_idx.shape == (1, ) and ep_idx[0] == 0
# test extra keys pop up, the buffer should handle it dynamically
batch = Batch(
obs=2,
act=2,
rew=2,
terminated=0,
truncated=0,
obs_next="str2",
info={
"a": 4,
"d": {
"e": -np.inf
}
}
)
b.add(batch)
info_keys = ["a", "b", "d"]
assert set(b.info.keys()) == set(info_keys)
assert b.info.a[1] == 4 and b.info.b.c[1] == 0
assert b.info.d.e[1] == -np.inf
# test batch-style adding method, where len(batch) == 1
batch.done = [1]
batch.terminated = [0]
batch.truncated = [1]
batch.info.e = np.zeros([1, 4])
batch = Batch.stack([batch])
ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
assert ptr.shape == (1, ) and ptr[0] == 2
assert ep_rew.shape == (1, ) and ep_rew[0] == 4
assert ep_len.shape == (1, ) and ep_len[0] == 2
assert ep_idx.shape == (1, ) and ep_idx[0] == 1
assert set(b.info.keys()) == set(info_keys + ["e"])
assert b.info.e.shape == (b.maxsize, 1, 4)
with pytest.raises(IndexError):
b[22]
# test prev / next
assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
batch.done = [0]
b.add(batch, buffer_ids=[0])
assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])
def test_ignore_obs_next(size=10):
# Issue 82
buf = ReplayBuffer(size, ignore_obs_next=True)
for i in range(size):
buf.add(
Batch(
obs={
'mask1': np.array([i, 1, 1, 0, 0]),
'mask2': np.array([i + 4, 0, 1, 0, 0]),
'mask': i
},
act={
'act_id': i,
'position_id': i + 3
},
rew=i,
terminated=i % 3 == 0,
truncated=False,
info={'if': i}
)
)
indices = np.arange(len(buf))
orig = np.arange(len(buf))
data = buf[indices]
data2 = buf[indices]
assert isinstance(data, Batch)
assert isinstance(data2, Batch)
assert np.allclose(indices, orig)
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
buf.stack_num = 4
data = buf[indices]
data2 = buf[indices]
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(
data.obs_next.mask,
np.array(
[
[0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], [4, 4, 4, 5],
[4, 4, 5, 6], [4, 4, 5, 6], [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]
]
)
)
assert np.allclose(data.info['if'], data2.info['if'])
assert np.allclose(
data.info['if'],
np.array(
[
[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [4, 4, 4, 4],
[4, 4, 4, 5], [4, 4, 5, 6], [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]
]
)
)
assert data.obs_next
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
obs, info = env.reset(options={"state": 1})
for _ in range(16):
obs_next, rew, terminated, truncated, info = env.step(1)
done = terminated or truncated
buf.add(
Batch(
obs=obs,
act=1,
rew=rew,
terminated=terminated,
truncated=truncated,
info=info
)
)
buf2.add(
Batch(
obs=obs,
act=1,
rew=rew,
terminated=terminated,
truncated=truncated,
info=info
)
)
buf3.add(
Batch(
obs=[obs, obs, obs],
act=1,
rew=rew,
terminated=terminated,
truncated=truncated,
obs_next=[obs, obs],
info=info
)
)
obs = obs_next
if done:
obs, info = env.reset(options={"state": 1})
indices = np.arange(len(buf))
assert np.allclose(
buf.get(indices, 'obs')[..., 0], [
[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2],
[1, 1, 2, 3], [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]
]
)
assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs'))
assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs_next'))
_, indices = buf2.sample(0)
assert indices.tolist() == [2, 6]
_, indices = buf2.sample(1)
assert indices[0] in [2, 6]
batch, indices = buf2.sample(-1) # neg bsz -> no data
assert indices.tolist() == [] and len(batch) == 0
with pytest.raises(IndexError):
buf[bufsize * 2]
def test_priortized_replaybuffer(size=32, bufsize=15):
env = MyTestEnv(size)
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
obs, info = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, act in enumerate(action_list):
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=act,
rew=rew,
terminated=terminated,
truncated=truncated,
obs_next=obs_next,
info=info,
policy=np.random.randn() - 0.5
)
batch_stack = Batch.stack([batch, batch, batch])
buf.add(Batch.stack([batch]), buffer_ids=[0])
buf2.add(batch_stack, buffer_ids=[0, 1, 2])
obs = obs_next
data, indices = buf.sample(len(buf) // 2)
if len(buf) // 2 == 0:
assert len(data) == len(buf)
else:
assert len(data) == len(buf) // 2
assert len(buf) == min(bufsize, i + 1)
assert len(buf2) == min(bufsize, 3 * (i + 1))
# check single buffer's data
assert buf.info.key.shape == (buf.maxsize, )
assert buf.rew.dtype == float
assert buf.done.dtype == bool
assert buf.terminated.dtype == bool
assert buf.truncated.dtype == bool
data, indices = buf.sample(len(buf) // 2)
buf.update_weight(indices, -data.weight / 2)
assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2)**buf._alpha)
# check multi buffer's data
assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1)
batch, indices = buf2.sample(10)
buf2.update_weight(indices, batch.weight * 0)
weight = buf2[np.arange(buf2.maxsize)].weight
mask = np.isin(np.arange(buf2.maxsize), indices)
assert np.all(weight[mask] == weight[mask][0])
assert np.all(weight[~mask] == weight[~mask][0])
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4):
env_size = size
env = MyGoalEnv(env_size, array_state=True)
def compute_reward_fn(ag, g):
return env.compute_reward_fn(ag, g, {})
buf = HERReplayBuffer(
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
)
buf2 = HERVectorReplayBuffer(
bufsize,
buffer_num=3,
compute_reward_fn=compute_reward_fn,
horizon=30,
future_k=8
)
# Apply her on every episodes sampled (Hacky but necessary for deterministic test)
buf.future_p = 1
for buf2_buf in buf2.buffers:
buf2_buf.future_p = 1
obs, _ = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, act in enumerate(action_list):
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=[act],
rew=rew,
terminated=terminated,
truncated=truncated,
obs_next=obs_next,
info=info
)
buf.add(batch)
buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2])
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
assert len(buf2) == min(bufsize, 3 * (i + 1))
batch, indices = buf.sample(sample_sz)
# Check that goals are the same for the episode (only 1 ep in buffer)
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf[tmp_indices].obs
obs_next = buf[tmp_indices].obs_next
rew = buf[tmp_indices].rew
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == g[0])
assert np.all(g_next == g_next[0])
assert np.all(rew == (ag_next == g).astype(np.float32))
tmp_indices = buf.next(tmp_indices)
# Check that goals are correctly restored
buf._restore_cache()
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf[tmp_indices].obs
obs_next = buf[tmp_indices].obs_next
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == env_size)
assert np.all(g_next == g_next[0])
assert np.all(g == g[0])
tmp_indices = buf.next(tmp_indices)
# Test vector buffer
batch, indices = buf2.sample(sample_sz)
# Check that goals are the same for the episode (only 1 ep in buffer)
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf2[tmp_indices].obs
obs_next = buf2[tmp_indices].obs_next
rew = buf2[tmp_indices].rew
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == g_next)
assert np.all(rew == (ag_next == g).astype(np.float32))
tmp_indices = buf2.next(tmp_indices)
# Check that goals are correctly restored
buf2._restore_cache()
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf2[tmp_indices].obs
obs_next = buf2[tmp_indices].obs_next
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == env_size)
assert np.all(g_next == g_next[0])
assert np.all(g == g[0])
tmp_indices = buf2.next(tmp_indices)
# Test handling cycled indices
env_size = size
bufsize = 15
env = MyGoalEnv(env_size, array_state=False)
def compute_reward_fn(ag, g):
return env.compute_reward_fn(ag, g, {})
buf = HERReplayBuffer(
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
)
buf._index = 5 # shifted start index
buf.future_p = 1
action_list = [1] * 10
for ep_len in [5, 10]:
obs, _ = env.reset()
for i in range(ep_len):
act = 1
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=[act],
rew=rew,
terminated=(i == ep_len - 1),
truncated=(i == ep_len - 1),
obs_next=obs_next,
info=info
)
buf.add(batch)
obs = obs_next
batch, indices = buf.sample(0)
assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal)
assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal)
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
# Another test case for cycled indices
env_size = 99
bufsize = 15
env = MyGoalEnv(env_size, array_state=False)
buf = HERReplayBuffer(
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
)
buf.future_p = 1
for x, ep_len in enumerate([10, 20]):
obs, _ = env.reset()
for i in range(ep_len):
act = 1
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=[act],
rew=rew,
terminated=(i == ep_len - 1),
truncated=(i == ep_len - 1),
obs_next=obs_next,
info=info
)
if x == 1 and obs["observation"] < 10:
obs = obs_next
continue
buf.add(batch)
obs = obs_next
buf._restore_cache()
sample_indices = np.array([10]) # Suppose the sampled indices is [10]
buf.rewrite_transitions(sample_indices)
assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
def test_update():
buf1 = ReplayBuffer(4, stack_num=2)
buf2 = ReplayBuffer(4, stack_num=2)
for i in range(5):
buf1.add(
Batch(
obs=np.array([i]),
act=float(i),
rew=i * i,
terminated=i % 2 == 0,
truncated=False,
info={'incident': 'found'}
)
)
assert len(buf1) > len(buf2)
buf2.update(buf1)
assert len(buf1) == len(buf2)
assert (buf2[0].obs == buf1[1].obs).all()
assert (buf2[-1].obs == buf1[0].obs).all()
b = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
with pytest.raises(NotImplementedError):
b.update(b)
def test_segtree():
realop = np.sum
# small test
actual_len = 8
tree = SegmentTree(actual_len) # 1-15. 8-15 are leaf nodes
assert len(tree) == actual_len
assert np.all([tree[i] == 0. for i in range(actual_len)])
with pytest.raises(IndexError):
tree[actual_len]
naive = np.zeros([actual_len])
for _ in range(1000):
# random choose a place to perform single update
index = np.random.randint(actual_len)
value = np.random.rand()
naive[index] = value
tree[index] = value
for i in range(actual_len):
for j in range(i + 1, actual_len):
ref = realop(naive[i:j])
out = tree.reduce(i, j)
assert np.allclose(ref, out), (ref, out)
assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
# batch setitem
for _ in range(1000):
index = np.random.choice(actual_len, size=4)
value = np.random.rand(4)
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for _ in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]), tree.reduce(left, right))
# large test
actual_len = 16384
tree = SegmentTree(actual_len)
naive = np.zeros([actual_len])
for _ in range(1000):
index = np.random.choice(actual_len, size=64)
value = np.random.rand(64)
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for _ in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]), tree.reduce(left, right))
# test prefix-sum-idx
actual_len = 8
tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# corner case here
naive = np.ones(actual_len, int)
tree[np.arange(actual_len)] = naive
for scalar in range(actual_len):
index = tree.get_prefix_sum_idx(scalar * 1.)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
tree = SegmentTree(10)
tree[np.arange(3)] = np.array([0.1, 0, 0.1])
assert np.allclose(
tree.get_prefix_sum_idx(np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2]
)
with pytest.raises(AssertionError):
tree.get_prefix_sum_idx(.2)
# test large prefix-sum-idx
actual_len = 16384
tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# profile
if __name__ == '__main__':
size = 100000
bsz = 64
naive = np.random.rand(size)
tree = SegmentTree(size)
tree[np.arange(size)] = naive
def sample_npbuf():
return np.random.choice(size, bsz, p=naive / naive.sum())
def sample_tree():
scalar = np.random.rand(bsz) * tree.reduce()
return tree.get_prefix_sum_idx(scalar)
print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
def test_pickle():
size = 100
vbuf = ReplayBuffer(size, stack_num=2)
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
rew = np.array([1, 1])
for i in range(4):
vbuf.add(
Batch(
obs=Batch(index=np.array([i])),
act=0,
rew=rew,
terminated=0,
truncated=0,
)
)
for i in range(5):
pbuf.add(
Batch(
obs=Batch(index=np.array([i])),
act=2,
rew=rew,
terminated=0,
truncated=0,
info=np.random.rand()
)
)
# save & load
_vbuf = pickle.loads(pickle.dumps(vbuf))
_pbuf = pickle.loads(pickle.dumps(pbuf))
assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act)
assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
# make sure the meta var is identical
assert _vbuf.stack_num == vbuf.stack_num
assert np.allclose(
_pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))]
)
def test_hdf5():
size = 100
buffers = {
"array": ReplayBuffer(size, stack_num=2),
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
info_t = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
'rew': np.array([1, 2]),
'terminated': i % 3 == 2,
'truncated': False,
'done': i % 3 == 2,
'info': {
"number": {
"n": i,
"t": info_t
},
'extra': None
},
}
buffers["array"].add(Batch(kwargs))
buffers["prioritized"].add(Batch(kwargs))
# save
paths = {}
for k, buf in buffers.items():
f, path = tempfile.mkstemp(suffix='.hdf5')
os.close(f)
buf.save_hdf5(path)
paths[k] = path
# load replay buffer
_buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
# compare
for k in buffers.keys():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k].maxsize == buffers[k].maxsize
assert np.all(_buffers[k]._indices == buffers[k]._indices)
for k in ["array", "prioritized"]:
assert _buffers[k]._index == buffers[k]._index
assert isinstance(buffers[k].get(0, "info"), Batch)
assert isinstance(_buffers[k].get(0, "info"), Batch)
for k in ["array"]:
assert np.all(buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra)
# raise exception when value cannot be pickled
data = {"not_supported": lambda x: x * x}
grp = h5py.Group
with pytest.raises(NotImplementedError):
to_hdf5(data, grp)
# ndarray with data type not supported by HDF5 that cannot be pickled
data = {"not_supported": np.array(lambda x: x * x)}
grp = h5py.Group
with pytest.raises(RuntimeError):
to_hdf5(data, grp)
def test_replaybuffermanager():
buf = VectorReplayBuffer(20, 4)
batch = Batch(
obs=[1, 2, 3],
act=[1, 2, 3],
rew=[1, 2, 3],
terminated=[0, 0, 1],
truncated=[0, 0, 0],
)
ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2])
assert np.all(ep_len == [0, 0, 1]) and np.all(ep_rew == [0, 0, 3])
assert np.all(ptr == [0, 5, 10]) and np.all(ep_idx == [0, 5, 10])
with pytest.raises(NotImplementedError):
# ReplayBufferManager cannot be updated
buf.update(buf)
# sample index / prev / next / unfinished_index
indices = buf.sample_indices(11000)
assert np.bincount(indices)[[0, 5, 10]].min() >= 3000 # uniform sample
batch, indices = buf.sample(0)
assert np.allclose(indices, [0, 5, 10])
indices_prev = buf.prev(indices)
assert np.allclose(indices_prev, indices), indices_prev
indices_next = buf.next(indices)
assert np.allclose(indices_next, indices), indices_next
assert np.allclose(buf.unfinished_index(), [0, 5])
buf.add(
Batch(obs=[4], act=[4], rew=[4], terminated=[1], truncated=[0]),
buffer_ids=[3]
)
assert np.allclose(buf.unfinished_index(), [0, 5])
batch, indices = buf.sample(10)
batch, indices = buf.sample(0)
assert np.allclose(indices, [0, 5, 10, 15])
indices_prev = buf.prev(indices)
assert np.allclose(indices_prev, indices), indices_prev
indices_next = buf.next(indices)
assert np.allclose(indices_next, indices), indices_next
data = np.array([0, 0, 0, 0])
buf.add(
Batch(obs=data, act=data, rew=data, terminated=data, truncated=data),
buffer_ids=[0, 1, 2, 3]
)
buf.add(
Batch(obs=data, act=data, rew=data, terminated=1 - data, truncated=data),
buffer_ids=[0, 1, 2, 3]
)
assert len(buf) == 12
buf.add(
Batch(obs=data, act=data, rew=data, terminated=data, truncated=data),
buffer_ids=[0, 1, 2, 3]
)
buf.add(
Batch(obs=data, act=data, rew=data, terminated=[0, 1, 0, 1], truncated=data),
buffer_ids=[0, 1, 2, 3]
)
assert len(buf) == 20
indices = buf.sample_indices(120000)
assert np.bincount(indices).min() >= 5000
batch, indices = buf.sample(10)
indices = buf.sample_indices(0)
assert np.allclose(indices, np.arange(len(buf)))
# check the actual data stored in buf._meta
assert np.allclose(
buf.done, [
0,
0,
1,
0,
0,
0,
0,
1,
0,
1,
1,
0,
1,
0,
0,
1,
0,
1,
0,
1,
]
)
assert np.allclose(
buf.prev(indices), [
0,
0,
1,
3,
3,
5,
5,
6,
8,
8,
10,
11,
11,
13,
13,
15,
16,
16,
18,
18,
]
)
assert np.allclose(
buf.next(indices), [
1,
2,
2,
4,
4,
6,
7,
7,
9,
9,
10,
12,
12,
14,
14,
15,
17,
17,
19,
19,
]
)
assert np.allclose(buf.unfinished_index(), [4, 14])
ptr, ep_rew, ep_len, ep_idx = buf.add(
Batch(obs=[1], act=[1], rew=[1], terminated=[1], truncated=[0]),
buffer_ids=[2]
)
assert np.all(ep_len == [3]) and np.all(ep_rew == [1])
assert np.all(ptr == [10]) and np.all(ep_idx == [13])
assert np.allclose(buf.unfinished_index(), [4])
indices = list(sorted(buf.sample_indices(0)))
assert np.allclose(indices, np.arange(len(buf)))
assert np.allclose(
buf.prev(indices), [
0,
0,
1,
3,
3,
5,
5,
6,
8,
8,
14,
11,
11,
13,
13,
15,
16,
16,
18,
18,
]
)
assert np.allclose(
buf.next(indices), [
1,
2,
2,
4,
4,
6,
7,
7,
9,
9,
10,
12,
12,
14,
10,
15,
17,
17,
19,
19,
]
)
# corner case: list, int and -1
assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
batch = buf._meta
batch.info = np.ones(buf.maxsize)
buf.set_batch(batch)
assert np.allclose(buf.buffers[-1].info, [1] * 5)
assert buf.sample_indices(-1).tolist() == []
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object
def test_cachedbuffer():
buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
assert buf.sample_indices(0).tolist() == []
# check the normal function/usage/storage in CachedReplayBuffer
ptr, ep_rew, ep_len, ep_idx = buf.add(
Batch(obs=[1], act=[1], rew=[1], terminated=[0], truncated=[0]),
buffer_ids=[1]
)
obs = np.zeros(buf.maxsize)
obs[15] = 1
indices = buf.sample_indices(0)
assert np.allclose(indices, [15])
assert np.allclose(buf.prev(indices), [15])
assert np.allclose(buf.next(indices), [15])
assert np.allclose(buf.obs, obs)
assert np.all(ep_len == [0]) and np.all(ep_rew == [0.0])
assert np.all(ptr == [15]) and np.all(ep_idx == [15])
ptr, ep_rew, ep_len, ep_idx = buf.add(
Batch(obs=[2], act=[2], rew=[2], terminated=[1], truncated=[0]),
buffer_ids=[3]
)
obs[[0, 25]] = 2
indices = buf.sample_indices(0)
assert np.allclose(indices, [0, 15])
assert np.allclose(buf.prev(indices), [0, 15])
assert np.allclose(buf.next(indices), [0, 15])
assert np.allclose(buf.obs, obs)
assert np.all(ep_len == [1]) and np.all(ep_rew == [2.0])
assert np.all(ptr == [0]) and np.all(ep_idx == [0])
assert np.allclose(buf.unfinished_index(), [15])
assert np.allclose(buf.sample_indices(0), [0, 15])
ptr, ep_rew, ep_len, ep_idx = buf.add(
Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], terminated=[0, 1], truncated=[0, 0]),
buffer_ids=[3, 1] # TODO
)
assert np.all(ep_len == [0, 2]) and np.all(ep_rew == [0, 5.0])
assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1])
obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
assert np.allclose(buf.obs, obs)
assert np.allclose(buf.unfinished_index(), [25])
indices = buf.sample_indices(0)
assert np.allclose(indices, [0, 1, 2, 25])
assert np.allclose(buf.done[indices], [1, 0, 1, 0])
assert np.allclose(buf.prev(indices), [0, 1, 1, 25])
assert np.allclose(buf.next(indices), [0, 2, 2, 25])
indices = buf.sample_indices(10000)
assert np.bincount(indices)[[0, 1, 2, 25]].min() > 2000 # uniform sample
# cached buffer with main_buffer size == 0 (no update)
# used in test_collector
buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
data = np.zeros(4)
rew = np.ones([4, 4])
buf.add(
Batch(
obs=data,
act=data,
rew=rew,
terminated=[0, 0, 1, 1],
truncated=[0, 0, 0, 0]
)
)
buf.add(
Batch(
obs=data,
act=data,
rew=rew,
terminated=[0, 0, 0, 0],
truncated=[0, 0, 0, 0]
)
)
buf.add(
Batch(
obs=data,
act=data,
rew=rew,
terminated=[1, 1, 1, 1],
truncated=[0, 0, 0, 0]
)
)
buf.add(
Batch(
obs=data,
act=data,
rew=rew,
terminated=[0, 0, 0, 0],
truncated=[0, 0, 0, 0]
)
)
ptr, ep_rew, ep_len, ep_idx = buf.add(
Batch(
obs=data,
act=data,
rew=rew,
terminated=[0, 1, 0, 1],
truncated=[0, 0, 0, 0]
)
)
assert np.all(ptr == [1, -1, 11, -1]) and np.all(ep_idx == [0, -1, 10, -1])
assert np.all(ep_len == [0, 2, 0, 2])
assert np.all(ep_rew == [data, data + 2, data, data + 2])
assert np.allclose(
buf.done, [
0,
0,
1,
0,
0,
0,
1,
1,
0,
0,
0,
0,
0,
0,
0,
0,
1,
0,
0,
0,
]
)
indices = buf.sample_indices(0)
assert np.allclose(indices, [0, 1, 10, 11])
assert np.allclose(buf.prev(indices), [0, 0, 10, 10])
assert np.allclose(buf.next(indices), [1, 1, 11, 11])
def test_multibuf_stack():
size = 5
bufsize = 9
stack_num = 4
cached_num = 3
env = MyTestEnv(size)
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
buf4 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), cached_num,
size
)
# test if CachedReplayBuffer can handle corner case:
# buffer + stack_num + ignore_obs_next + sample_avail
buf5 = CachedReplayBuffer(
ReplayBuffer(
bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True
), cached_num, size
)
obs, info = env.reset(options={"state": 1})
for i in range(18):
obs_next, rew, terminated, truncated, info = env.step(1)
done = terminated or truncated
obs_list = np.array([obs + size * i for i in range(cached_num)])
act_list = [1] * cached_num
rew_list = [rew] * cached_num
terminated_list = [terminated] * cached_num
truncated_list = [truncated] * cached_num
obs_next_list = -obs_list
info_list = [info] * cached_num
batch = Batch(
obs=obs_list,
act=act_list,
rew=rew_list,
terminated=terminated_list,
truncated=truncated_list,
obs_next=obs_next_list,
info=info_list
)
buf5.add(batch)
buf4.add(batch)
assert np.all(buf4.obs == buf5.obs)
assert np.all(buf4.done == buf5.done)
assert np.all(buf4.terminated == buf5.terminated)
assert np.all(buf4.truncated == buf5.truncated)
obs = obs_next
if done:
obs, info = env.reset(options={"state": 1})
# check the `add` order is correct
assert np.allclose(
buf4.obs.reshape(-1),
[
12,
13,
14,
4,
6,
7,
8,
9,
11, # main_buffer
1,
2,
3,
4,
0, # cached_buffer[0]
6,
7,
8,
9,
0, # cached_buffer[1]
11,
12,
13,
14,
0, # cached_buffer[2]
]
), buf4.obs
assert np.allclose(
buf4.done,
[
0,
0,
1,
1,
0,
0,
0,
1,
0, # main_buffer
0,
0,
0,
1,
0, # cached_buffer[0]
0,
0,
0,
1,
0, # cached_buffer[1]
0,
0,
0,
1,
0, # cached_buffer[2]
]
), buf4.done
assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
indices = sorted(buf4.sample_indices(0))
assert np.allclose(indices, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
assert np.allclose(
buf4[indices].obs[..., 0], [
[11, 11, 11, 12],
[11, 11, 12, 13],
[11, 12, 13, 14],
[4, 4, 4, 4],
[6, 6, 6, 6],
[6, 6, 6, 7],
[6, 6, 7, 8],
[6, 7, 8, 9],
[11, 11, 11, 11],
[1, 1, 1, 1],
[1, 1, 1, 2],
[6, 6, 6, 6],
[6, 6, 6, 7],
[11, 11, 11, 11],
[11, 11, 11, 12],
]
)
assert np.allclose(
buf4[indices].obs_next[..., 0], [
[11, 11, 12, 13],
[11, 12, 13, 14],
[11, 12, 13, 14],
[4, 4, 4, 4],
[6, 6, 6, 7],
[6, 6, 7, 8],
[6, 7, 8, 9],
[6, 7, 8, 9],
[11, 11, 11, 12],
[1, 1, 1, 2],
[1, 1, 1, 2],
[6, 6, 6, 7],
[6, 6, 6, 7],
[11, 11, 11, 12],
[11, 11, 11, 12],
]
)
indices = buf5.sample_indices(0)
assert np.allclose(sorted(indices), [2, 7])
assert np.all(np.isin(buf5.sample_indices(100), indices))
# manually change the stack num
buf5.stack_num = 2
for buf in buf5.buffers:
buf.stack_num = 2
indices = buf5.sample_indices(0)
assert np.allclose(sorted(indices), [0, 1, 2, 5, 6, 7, 10, 15, 20])
batch, _ = buf5.sample(0)
# test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
buf6 = CachedReplayBuffer(
ReplayBuffer(
bufsize,
stack_num=stack_num,
save_only_last_obs=True,
ignore_obs_next=True
), cached_num, size
)
obs = np.random.rand(size, 4, 84, 84)
buf6.add(
Batch(
obs=[obs[2], obs[0]],
act=[1, 1],
rew=[0, 0],
terminated=[0, 1],
truncated=[0, 0],
obs_next=[obs[3], obs[1]]
),
buffer_ids=[1, 2]
)
assert buf6.obs.shape == (buf6.maxsize, 84, 84)
assert np.allclose(buf6.obs[0], obs[0, -1])
assert np.allclose(buf6.obs[14], obs[2, -1])
assert np.allclose(buf6.obs[19], obs[0, -1])
assert buf6[0].obs.shape == (4, 84, 84)
def test_multibuf_hdf5():
size = 100
buffers = {
"vector": VectorReplayBuffer(size * 4, 4),
"cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
info_t = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
'rew': np.array([1, 2]),
'terminated': i % 3 == 2,
'truncated': False,
'done': i % 3 == 2,
'info': {
"number": {
"n": i,
"t": info_t
},
'extra': None
},
}
buffers["vector"].add(
Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]
)
buffers["cached"].add(
Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]
)
# save
paths = {}
for k, buf in buffers.items():
f, path = tempfile.mkstemp(suffix='.hdf5')
os.close(f)
buf.save_hdf5(path)
paths[k] = path
# load replay buffer
_buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
# compare
for k in buffers.keys():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k].maxsize == buffers[k].maxsize
assert np.all(_buffers[k]._indices == buffers[k]._indices)
# check shallow copy in VectorReplayBuffer
for k in ["vector", "cached"]:
buffers[k].info.number.n[0] = -100
assert buffers[k].buffers[0].info.number.n[0] == -100
# check if still behave normally
for k in ["vector", "cached"]:
kwargs = {
'obs': Batch(index=np.array([5])),
'act': 5,
'rew': np.array([2, 1]),
'terminated': False,
'truncated': False,
'done': False,
'info': {
"number": {
"n": i
},
'Timelimit.truncate': True
},
}
buffers[k].add(Batch.stack([kwargs, kwargs, kwargs, kwargs]))
act = np.zeros(buffers[k].maxsize)
if k == "vector":
act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
act[size * 3] = 5
elif k == "cached":
act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
act[np.arange(3) + size] = np.array([3, 5, 2])
act[np.arange(3) + size * 2] = np.array([3, 5, 2])
act[np.arange(3) + size * 3] = np.array([3, 5, 2])
act[size * 4] = 5
assert np.allclose(buffers[k].act, act)
info_keys = ["number", "extra", "Timelimit.truncate"]
assert set(buffers[k].info.keys()) == set(info_keys)
for path in paths.values():
os.remove(path)
def test_from_data():
obs_data = np.ndarray((10, 3, 3), dtype="uint8")
for i in range(10):
obs_data[i] = i * np.ones((3, 3), dtype="uint8")
obs_next_data = np.zeros_like(obs_data)
obs_next_data[:-1] = obs_data[1:]
f, path = tempfile.mkstemp(suffix='.hdf5')
os.close(f)
with h5py.File(path, "w") as f:
obs = f.create_dataset("obs", data=obs_data)
act = f.create_dataset("act", data=np.arange(10, dtype="int32"))
rew = f.create_dataset("rew", data=np.arange(10, dtype="float32"))
terminated = f.create_dataset("terminated", data=np.zeros(10, dtype="bool"))
truncated = f.create_dataset("truncated", data=np.zeros(10, dtype="bool"))
done = f.create_dataset("done", data=np.zeros(10, dtype="bool"))
obs_next = f.create_dataset("obs_next", data=obs_next_data)
buf = ReplayBuffer.from_data(
obs, act, rew, terminated, truncated, done, obs_next
)
assert len(buf) == 10
batch = buf[3]
assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8"))
assert batch.act == 3
assert batch.rew == 3.0
assert not batch.done
assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8"))
os.remove(path)
if __name__ == '__main__':
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_segtree()
test_priortized_replaybuffer()
test_update()
test_pickle()
test_hdf5()
test_replaybuffermanager()
test_cachedbuffer()
test_multibuf_stack()
test_multibuf_hdf5()
test_from_data()
test_herreplaybuffer()