Tianshou/test/base/test_buffer.py
yingchengyang 5b49192a48
DQN Atari examples ()
This PR aims to provide the script of Atari DQN setting:
- A speedrun of PongNoFrameskip-v4 (finished, about half an hour in i7-8750 + GTX1060 with 1M environment steps)
- A general script for all atari game
Since we use multiple env for simulation, the result is slightly different from the original paper, but consider to be acceptable.

It also adds another parameter save_only_last_obs for replay buffer in order to save the memory.

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-30 05:48:09 +08:00

291 lines
11 KiB
Python

import torch
import pickle
import pytest
import numpy as np
from timeit import timeit
from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
else: # pytest
from test.base.env import MyTestEnv
def test_replaybuffer(size=10, bufsize=20):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize)
buf.update(buf)
assert str(buf) == buf.__class__.__name__ + '()'
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
buf.add(obs, [a], rew, done, obs_next, info)
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
with pytest.raises(ValueError):
buf._add_to_buffer('rew', np.array([1, 2, 3]))
assert buf.act.dtype == np.object
assert isinstance(buf.act[0], list)
data, indice = buf.sample(bufsize * 2)
assert (indice < len(buf)).all()
assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all()
b = ReplayBuffer(size=10)
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
assert b.obs[0] == 1
assert b.done[0] == 'str'
assert np.all(b.obs[1:] == 0)
assert np.all(b.done[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
assert np.all(b.info.b.c[1:] == 0.0)
with pytest.raises(IndexError):
b[22]
b = ListReplayBuffer()
with pytest.raises(NotImplementedError):
b.sample(0)
def test_ignore_obs_next(size=10):
# Issue 82
buf = ReplayBuffer(size, ignore_obs_next=True)
for i in range(size):
buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
'mask2': np.array([i + 4, 0, 1, 0, 0]),
'mask': i},
act={'act_id': i,
'position_id': i + 3},
rew=i,
done=i % 3 == 0,
info={'if': i})
indice = np.arange(len(buf))
orig = np.arange(len(buf))
data = buf[indice]
data2 = buf[indice]
assert isinstance(data, Batch)
assert isinstance(data2, Batch)
assert np.allclose(indice, orig)
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
buf.stack_num = 4
data = buf[indice]
data2 = buf[indice]
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(data.obs_next.mask, np.array([
[0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
[4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
[7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
assert np.allclose(data.info['if'], data2.info['if'])
assert np.allclose(data.info['if'], np.array([
[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
[7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
assert data.obs_next
def test_stack(size=5, bufsize=9, stack_num=4):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
obs = env.reset(1)
for i in range(16):
obs_next, rew, done, info = env.step(1)
buf.add(obs, 1, rew, done, None, info)
buf2.add(obs, 1, rew, done, None, info)
buf3.add([None, None, obs], 1, rew, done, [None, obs], info)
obs = obs_next
if done:
obs = env.reset(1)
indice = np.arange(len(buf))
assert np.allclose(buf.get(indice, 'obs')[..., 0], [
[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs'))
assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next'))
_, indice = buf2.sample(0)
assert indice.tolist() == [2, 6]
_, indice = buf2.sample(1)
assert indice in [2, 6]
with pytest.raises(IndexError):
buf[bufsize * 2]
def test_priortized_replaybuffer(size=32, bufsize=15):
env = MyTestEnv(size)
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
obs = obs_next
data, indice = buf.sample(len(buf) // 2)
if len(buf) // 2 == 0:
assert len(data) == len(buf)
else:
assert len(data) == len(buf) // 2
assert len(buf) == min(bufsize, i + 1)
data, indice = buf.sample(len(buf) // 2)
buf.update_weight(indice, -data.weight / 2)
assert np.allclose(
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
def test_update():
buf1 = ReplayBuffer(4, stack_num=2)
buf2 = ReplayBuffer(4, stack_num=2)
for i in range(5):
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
done=i % 2 == 0, info={'incident': 'found'})
assert len(buf1) > len(buf2)
buf2.update(buf1)
assert len(buf1) == len(buf2)
assert (buf2[0].obs == buf1[1].obs).all()
assert (buf2[-1].obs == buf1[0].obs).all()
def test_segtree():
for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
realop = getattr(np, op)
# small test
actual_len = 8
tree = SegmentTree(actual_len, op) # 1-15. 8-15 are leaf nodes
assert len(tree) == actual_len
assert np.all([tree[i] == init for i in range(actual_len)])
with pytest.raises(IndexError):
tree[actual_len]
naive = np.full([actual_len], init)
for _ in range(1000):
# random choose a place to perform single update
index = np.random.randint(actual_len)
value = np.random.rand()
naive[index] = value
tree[index] = value
for i in range(actual_len):
for j in range(i + 1, actual_len):
ref = realop(naive[i:j])
out = tree.reduce(i, j)
assert np.allclose(ref, out)
assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
# batch setitem
for _ in range(1000):
index = np.random.choice(actual_len, size=4)
value = np.random.rand(4)
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for i in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
# large test
actual_len = 16384
tree = SegmentTree(actual_len, op)
naive = np.full([actual_len], init)
for _ in range(1000):
index = np.random.choice(actual_len, size=64)
value = np.random.rand(64)
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for i in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
# test prefix-sum-idx
actual_len = 8
tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# corner case here
naive = np.ones(actual_len, np.int)
tree[np.arange(actual_len)] = naive
for scalar in range(actual_len):
index = tree.get_prefix_sum_idx(scalar * 1.)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
tree = SegmentTree(10)
tree[np.arange(3)] = np.array([0.1, 0, 0.1])
assert np.allclose(tree.get_prefix_sum_idx(
np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
with pytest.raises(AssertionError):
tree.get_prefix_sum_idx(.2)
# test large prefix-sum-idx
actual_len = 16384
tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# profile
if __name__ == '__main__':
size = 100000
bsz = 64
naive = np.random.rand(size)
tree = SegmentTree(size)
tree[np.arange(size)] = naive
def sample_npbuf():
return np.random.choice(size, bsz, p=naive / naive.sum())
def sample_tree():
scalar = np.random.rand(bsz) * tree.reduce()
return tree.get_prefix_sum_idx(scalar)
print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
def test_pickle():
size = 100
vbuf = ReplayBuffer(size, stack_num=2)
lbuf = ListReplayBuffer()
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rew = torch.tensor([1.]).to(device)
for i in range(4):
vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
for i in range(3):
lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0)
for i in range(5):
pbuf.add(obs=Batch(index=np.array([i])),
act=2, rew=rew, done=0, weight=np.random.rand())
# save & load
_vbuf = pickle.loads(pickle.dumps(vbuf))
_lbuf = pickle.loads(pickle.dumps(lbuf))
_pbuf = pickle.loads(pickle.dumps(pbuf))
assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act)
assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act)
assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
# make sure the meta var is identical
assert _vbuf.stack_num == vbuf.stack_num
assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
pbuf.weight[np.arange(len(pbuf))])
if __name__ == '__main__':
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_pickle()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update()