This PR provides scripts for the Atari DQN setting: - a speedrun of PongNoFrameskip-v4 (finished in about half an hour on an i7-8750 + GTX 1060 with 1M environment steps); - a general script for all Atari games. Since we use multiple environments for simulation, the results differ slightly from the original paper but are considered acceptable. It also adds a `save_only_last_obs` parameter to the replay buffer to save memory. Co-authored-by: Trinkle23897 <463003665@qq.com>
291 lines
11 KiB
Python
291 lines
11 KiB
Python
import torch
|
|
import pickle
|
|
import pytest
|
|
import numpy as np
|
|
from timeit import timeit
|
|
|
|
from tianshou.data import Batch, SegmentTree, \
|
|
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
|
|
|
|
# Import the shared toy environment. The import path depends on how this
# module was loaded: as a package module (e.g. collected by pytest) or
# executed directly as a script next to env.py.
if __name__ != '__main__':  # pytest
    from test.base.env import MyTestEnv
else:  # direct script execution
    from env import MyTestEnv
|
|
|
|
|
|
def test_replaybuffer(size=10, bufsize=20):
    """Exercise basic ReplayBuffer / ListReplayBuffer behaviour.

    Covers: self-update of an empty buffer, filling past capacity,
    rejecting a shape-mismatched column write, object-dtype storage for
    list-valued actions, sampling, lazily created (nested) ``info``
    columns, out-of-range indexing, and the unsupported
    ``ListReplayBuffer.sample``.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # Update an empty buffer with itself; the repr check below expects it
    # to still be empty afterwards.
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        # Actions are deliberately wrapped in a list so that the ``act``
        # column is stored with an object dtype (checked below).
        buf.add(obs, [a], rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    # Writing a length-3 array into the scalar 'rew' column must fail.
    with pytest.raises(ValueError):
        buf._add_to_buffer('rew', np.array([1, 2, 3]))
    # ``np.object`` was a deprecated alias removed in NumPy 1.24; the
    # builtin ``object`` is the supported spelling.
    assert buf.act.dtype == object
    assert isinstance(buf.act[0], list)
    # Requesting more samples than stored is allowed; every index must be
    # a valid position.
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    # Columns are created from the first added transition; unwritten slots
    # must hold zero-like defaults (None for the object-typed 'done').
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    # Compare dtype *kinds* with np.issubdtype: ``dtype == np.integer``
    # goes through a deprecated abstract-type-to-dtype conversion and only
    # matches the platform default int/float width.
    assert b.info.a[0] == 3 and np.issubdtype(b.info.a.dtype, np.integer)
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and np.issubdtype(
        b.info.b.c.dtype, np.inexact)
    assert np.all(b.info.b.c[1:] == 0.0)
    with pytest.raises(IndexError):
        b[22]
    b = ListReplayBuffer()
    with pytest.raises(NotImplementedError):
        b.sample(0)
|
|
|
|
|
|
def test_ignore_obs_next(size=10):
    """Check ``obs_next`` reconstruction with ``ignore_obs_next=True``.

    Regression test for issue 82. ``obs_next`` is never stored, so the
    buffer must rebuild it from neighbouring ``obs`` entries. Per the
    expected values below, a done step's ``obs_next`` equals its own
    ``obs``. The check runs both without stacking and with
    ``stack_num=4``. Indexing twice must give identical results and must
    leave the index array untouched.
    """
    buf = ReplayBuffer(size, ignore_obs_next=True)
    for k in range(size):
        observation = {
            'mask1': np.array([k, 1, 1, 0, 0]),
            'mask2': np.array([k + 4, 0, 1, 0, 0]),
            'mask': k,
        }
        action = {'act_id': k, 'position_id': k + 3}
        buf.add(obs=observation, act=action, rew=k,
                done=k % 3 == 0, info={'if': k})

    idx = np.arange(len(buf))
    idx_reference = np.arange(len(buf))

    # Without stacking.
    first = buf[idx]
    second = buf[idx]
    assert isinstance(first, Batch)
    assert isinstance(second, Batch)
    # Indexing must not modify the caller's index array in place.
    assert np.allclose(idx, idx_reference)
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    assert np.allclose(first.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])

    # With 4-frame stacking.
    buf.stack_num = 4
    first = buf[idx]
    second = buf[idx]
    stacked_next_mask = np.array([
        [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
        [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
        [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])
    stacked_info_if = np.array([
        [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
        [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    assert np.allclose(first.obs_next.mask, stacked_next_mask)
    assert np.allclose(first.info['if'], second.info['if'])
    assert np.allclose(first.info['if'], stacked_info_if)
    assert first.obs_next
|
|
|
|
|
|
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check frame stacking, ``sample_avail`` and ``save_only_last_obs``.

    Three buffers receive the same 16-step stream of episodes.
    ``last_obs_buf`` is fed padded observation lists (``save_only_last_obs``
    presumably keeps only the final element). The test expects its stacked
    ``obs`` and ``obs_next`` to equal the plain buffer's stacked ``obs``.
    """
    env = MyTestEnv(size)
    plain_buf = ReplayBuffer(bufsize, stack_num=stack_num)
    avail_buf = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    last_obs_buf = ReplayBuffer(
        bufsize, stack_num=stack_num, save_only_last_obs=True)

    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        plain_buf.add(obs, 1, rew, done, None, info)
        avail_buf.add(obs, 1, rew, done, None, info)
        last_obs_buf.add([None, None, obs], 1, rew, done, [None, obs], info)
        # Start a fresh episode whenever the current one finishes.
        obs = env.reset(1) if done else obs_next

    every_idx = np.arange(len(plain_buf))
    expected_first_feature = [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]
    assert np.allclose(
        plain_buf.get(every_idx, 'obs')[..., 0], expected_first_feature)
    assert np.allclose(plain_buf.get(every_idx, 'obs'),
                       last_obs_buf.get(every_idx, 'obs'))
    assert np.allclose(plain_buf.get(every_idx, 'obs'),
                       last_obs_buf.get(every_idx, 'obs_next'))

    # With sample_avail=True, sample(0) is expected to return exactly the
    # indices 2 and 6, and sample(1) must draw one of them.
    _, sampled = avail_buf.sample(0)
    assert sampled.tolist() == [2, 6]
    _, sampled = avail_buf.sample(1)
    assert sampled in [2, 6]

    with pytest.raises(IndexError):
        plain_buf[bufsize * 2]
|
|
|
|
|
|
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Exercise PrioritizedReplayBuffer sampling and priority updates.

    After every insertion, sampling half the buffer must yield that many
    items. When half rounds down to 0, the whole buffer is expected back.
    Finally, ``update_weight`` with negative priorities must store
    ``|priority| ** alpha``.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, act in enumerate(actions):
        obs_next, rew, done, info = env.step(act)
        # The initial priority (randn() - 0.5) may be negative.
        buf.add(obs, act, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        half = len(buf) // 2
        data, indice = buf.sample(half)
        expected_len = len(buf) if half == 0 else half
        assert len(data) == expected_len
        assert len(buf) == min(bufsize, step + 1)

    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    expected_weight = np.abs(-data.weight / 2) ** buf._alpha
    assert np.allclose(buf.weight[indice], expected_weight)
|
|
|
|
|
|
def test_update():
    """``ReplayBuffer.update`` copies another buffer's contents in order.

    The source buffer (capacity 4) receives 5 transitions, so slot 0 is
    overwritten. The test expects the copy to start at the source's slot 1
    (the oldest surviving entry) and to end with the source's slot 0.
    """
    src = ReplayBuffer(4, stack_num=2)
    dst = ReplayBuffer(4, stack_num=2)
    for step in range(5):
        src.add(obs=np.array([step]), act=float(step), rew=step * step,
                done=step % 2 == 0, info={'incident': 'found'})
    assert len(src) > len(dst)

    dst.update(src)
    assert len(src) == len(dst)
    assert (dst[0].obs == src[1].obs).all()
    assert (dst[-1].obs == src[0].obs).all()
|
|
|
|
|
|
def test_segtree():
    """Check SegmentTree against a naive NumPy array reference.

    For each of the 'sum', 'max' and 'min' operators, random single and
    batched updates are mirrored into a plain array. Every reduction is
    then compared with the matching NumPy function.

    A tree built with the default operator (used as a sum tree here) is
    then checked with ``get_prefix_sum_idx`` on four kinds of input:
    random weights, all-ones weights, zero-weight leaves, and a large
    tree.

    When the module is executed as a script, a timing comparison against
    ``np.random.choice`` is also printed.
    """
    for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
        realop = getattr(np, op)
        # small test
        actual_len = 8
        tree = SegmentTree(actual_len, op)  # 1-15. 8-15 are leaf nodes
        assert len(tree) == actual_len
        # Untouched leaves must hold the operator's identity value.
        assert np.all([tree[i] == init for i in range(actual_len)])
        with pytest.raises(IndexError):
            tree[actual_len]
        naive = np.full([actual_len], init)
        for _ in range(1000):
            # randomly choose a place to perform a single update
            index = np.random.randint(actual_len)
            value = np.random.rand()
            naive[index] = value
            tree[index] = value
            # compare half-open ranges [i, j) for all i < j < actual_len
            for i in range(actual_len):
                for j in range(i + 1, actual_len):
                    ref = realop(naive[i:j])
                    out = tree.reduce(i, j)
                    assert np.allclose(ref, out)
        assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
        assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
        # batch setitem
        for _ in range(1000):
            index = np.random.choice(actual_len, size=4)
            value = np.random.rand(4)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(realop(naive[left:right]),
                                   tree.reduce(left, right))
        # large test
        actual_len = 16384
        tree = SegmentTree(actual_len, op)
        naive = np.full([actual_len], init)
        for _ in range(1000):
            index = np.random.choice(actual_len, size=64)
            value = np.random.rand(64)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(realop(naive[left:right]),
                                   tree.reduce(left, right))

    # test prefix-sum-idx
    actual_len = 8
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # Corner case: integer weights, so queries land exactly on prefix-sum
    # boundaries. ``np.int`` was removed in NumPy 1.24; use builtin int.
    naive = np.ones(actual_len, dtype=int)
    tree[np.arange(actual_len)] = naive
    for scalar in range(actual_len):
        index = tree.get_prefix_sum_idx(scalar * 1.)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # The zero-weight leaf 1 must never be returned, and a query equal to
    # the total sum (0.2) is expected to be rejected.
    tree = SegmentTree(10)
    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
    assert np.allclose(tree.get_prefix_sum_idx(
        np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
    with pytest.raises(AssertionError):
        tree.get_prefix_sum_idx(.2)
    # test large prefix-sum-idx
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()

    # profile (only when this file is executed directly)
    if __name__ == '__main__':
        size = 100000
        bsz = 64
        naive = np.random.rand(size)
        tree = SegmentTree(size)
        tree[np.arange(size)] = naive

        def sample_npbuf():
            # proportional sampling via NumPy's weighted choice
            return np.random.choice(size, bsz, p=naive / naive.sum())

        def sample_tree():
            # proportional sampling via prefix-sum lookups on the tree
            scalar = np.random.rand(bsz) * tree.reduce()
            return tree.get_prefix_sum_idx(scalar)

        print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
        print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
|
|
|
|
|
|
def test_pickle():
    """Round-trip each buffer type through pickle and compare the results.

    ``rew`` is a torch tensor (on CUDA when available). The restored
    buffers must keep their length, ``act`` column, ``stack_num`` and
    (for the prioritized buffer) per-slot weights.
    """
    size = 100
    vbuf = ReplayBuffer(size, stack_num=2)
    lbuf = ListReplayBuffer()
    pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    rew = torch.tensor([1.]).to(device)

    for k in range(4):
        vbuf.add(obs=Batch(index=np.array([k])), act=0, rew=rew, done=0)
    for k in range(3):
        lbuf.add(obs=Batch(index=np.array([k])), act=1, rew=rew, done=0)
    for k in range(5):
        pbuf.add(obs=Batch(index=np.array([k])),
                 act=2, rew=rew, done=0, weight=np.random.rand())

    def roundtrip(buffer):
        # save & load through an in-memory pickle
        return pickle.loads(pickle.dumps(buffer))

    restored_v = roundtrip(vbuf)
    restored_l = roundtrip(lbuf)
    restored_p = roundtrip(pbuf)
    for original, restored in ((vbuf, restored_v), (lbuf, restored_l),
                               (pbuf, restored_p)):
        assert len(restored) == len(original) and np.allclose(
            restored.act, original.act)

    # make sure the meta variables survive as well
    assert restored_v.stack_num == vbuf.stack_num
    assert np.allclose(restored_p.weight[np.arange(len(restored_p))],
                       pbuf.weight[np.arange(len(pbuf))])
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the whole module without pytest, in a fixed order. The prioritized
    # buffer test runs twice: once with its defaults and once with a much
    # larger buffer (bufsize=200000).
    for test_fn, test_args in [
        (test_replaybuffer, ()),
        (test_ignore_obs_next, ()),
        (test_stack, ()),
        (test_pickle, ()),
        (test_segtree, ()),
        (test_priortized_replaybuffer, ()),
        (test_priortized_replaybuffer, (233333, 200000)),
        (test_update, ()),
    ]:
        test_fn(*test_args)
|