Training FPS improvement (base commit is 94bfb32):
- test_pdqn: 1660 (without numba) -> 1930
- discrete/test_ppo: 5100 -> 5170
Since nstep has little impact on overall performance, the unit test results are:
- GAE: 4.1s -> 0.057s
- nstep: 0.3s -> 0.15s (little improvement)
Others:
- fix a bug in ttt set_eps
- keep only sumtree in segment tree implementation
- dirty fix for asyncVenv check_id test
		
			
				
	
	
		
			290 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			290 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| import pickle
 | |
| import pytest
 | |
| import numpy as np
 | |
| from timeit import timeit
 | |
| 
 | |
| from tianshou.data import Batch, SegmentTree, \
 | |
|     ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
 | |
| 
 | |
| if __name__ == '__main__':
 | |
|     from env import MyTestEnv
 | |
| else:  # pytest
 | |
|     from test.base.env import MyTestEnv
 | |
| 
 | |
| 
 | |
def test_replaybuffer(size=10, bufsize=20):
    """Check basic add / sample / indexing behaviour of ReplayBuffer.

    :param size: size passed to ``MyTestEnv``; sampled observations are
        asserted to be strictly below it.
    :param bufsize: capacity of the buffer under test.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # Updating a still-empty buffer with itself must leave it empty, so the
    # repr is just the bare class name.
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        # The action is wrapped in a list on purpose: the dtype assertions
        # below expect it to be stored as a Python object.
        buf.add(obs, [a], rew, done, obs_next, info)
        obs = obs_next
        # 25 steps go into a buffer of ``bufsize`` slots, so the length
        # saturates once the buffer is full.
        assert len(buf) == min(bufsize, i + 1)
    with pytest.raises(ValueError):
        buf._add_to_buffer('rew', np.array([1, 2, 3]))
    # ``np.object`` was only an alias of the builtin and has been removed
    # from NumPy; comparing against ``object`` is equivalent.
    assert buf.act.dtype == object
    assert isinstance(buf.act[0], list)
    # Ask for more samples than the buffer holds.
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()

    # A fresh buffer with a single heterogeneous transition: the untouched
    # slots are expected to hold zeros (numeric keys) or None (object keys).
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    # Comparing a dtype with an abstract scalar type (``dtype == np.integer``)
    # is deprecated; ``np.issubdtype`` is the supported category check.
    assert b.info.a[0] == 3 and np.issubdtype(b.info.a.dtype, np.integer)
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and \
        np.issubdtype(b.info.b.c.dtype, np.inexact)
    assert np.all(b.info.b.c[1:] == 0.0)
    with pytest.raises(IndexError):
        b[22]

    # ListReplayBuffer is expected not to implement sampling.
    b = ListReplayBuffer()
    with pytest.raises(NotImplementedError):
        b.sample(0)
 | |
| 
 | |
| 
 | |
def test_ignore_obs_next(size=10):
    """Regression test for issue 82 (buffer built with ignore_obs_next=True).

    Fills the buffer with dict observations and no ``obs_next``, then checks
    that ``obs_next`` and ``info`` read back with the expected values, both
    without frame stacking and with ``stack_num = 4``.

    :param size: buffer capacity and number of transitions added.
    """
    buf = ReplayBuffer(size, ignore_obs_next=True)
    for step in range(size):
        observation = {
            'mask1': np.array([step, 1, 1, 0, 0]),
            'mask2': np.array([step + 4, 0, 1, 0, 0]),
            'mask': step,
        }
        action = {'act_id': step, 'position_id': step + 3}
        # An episode ends on every third step (0, 3, 6, 9).
        buf.add(obs=observation, act=action, rew=step,
                done=step % 3 == 0, info={'if': step})

    indice = np.arange(len(buf))
    untouched = indice.copy()

    # Fetch twice: the two reads must agree, and indexing must not modify
    # the index array that was passed in.
    first = buf[indice]
    second = buf[indice]
    assert isinstance(first, Batch)
    assert isinstance(second, Batch)
    assert np.allclose(indice, untouched)
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    expected_flat = [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]
    assert np.allclose(first.obs_next.mask, expected_flat)

    # Same checks with frame stacking switched on after the fact.
    buf.stack_num = 4
    first = buf[indice]
    second = buf[indice]
    expected_obs_next = np.array([
        [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
        [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
        [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])
    expected_info = np.array([
        [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
        [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])
    assert np.allclose(first.obs_next.mask, second.obs_next.mask)
    assert np.allclose(first.obs_next.mask, expected_obs_next)
    assert np.allclose(first.info['if'], second.info['if'])
    assert np.allclose(first.info['if'], expected_info)
    assert first.obs_next
 | |
| 
 | |
| 
 | |
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check frame stacking for three differently configured buffers.

    :param size: size passed to ``MyTestEnv``.
    :param bufsize: capacity of every buffer.
    :param stack_num: number of stacked frames requested from the buffers.
    """
    env = MyTestEnv(size)
    plain = ReplayBuffer(bufsize, stack_num=stack_num)
    avail_only = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    last_obs_only = ReplayBuffer(
        bufsize, stack_num=stack_num, save_only_last_obs=True)

    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        plain.add(obs, 1, rew, done, None, info)
        avail_only.add(obs, 1, rew, done, None, info)
        # Only the trailing element of these lists is a real observation;
        # the comparisons below expect this buffer to match ``plain``.
        last_obs_only.add(
            [None, None, obs], 1, rew, done, [None, obs], info)
        obs = env.reset(1) if done else obs_next

    indice = np.arange(len(plain))
    expected_stack = [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]
    assert np.allclose(plain.get(indice, 'obs')[..., 0], expected_stack)
    assert np.allclose(
        plain.get(indice, 'obs'), last_obs_only.get(indice, 'obs'))
    assert np.allclose(
        plain.get(indice, 'obs'), last_obs_only.get(indice, 'obs_next'))

    # With sample_avail=True only indices 2 and 6 are expected to be offered.
    _, sampled = avail_only.sample(0)
    assert sampled.tolist() == [2, 6]
    _, sampled = avail_only.sample(1)
    assert sampled in [2, 6]

    with pytest.raises(IndexError):
        plain[bufsize * 2]
 | |
| 
 | |
| 
 | |
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Check sampling sizes and weight updates of PrioritizedReplayBuffer.

    :param size: size passed to ``MyTestEnv``.
    :param bufsize: capacity of the prioritized buffer.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, action in enumerate(actions):
        obs_next, rew, done, info = env.step(action)
        # Random initial priority; it may be negative.
        priority = np.random.randn() - 0.5
        buf.add(obs, action, rew, done, obs_next, info, priority)
        obs = obs_next
        half = len(buf) // 2
        data, indice = buf.sample(half)
        # A request for 0 samples is expected to return the whole buffer.
        expected_len = len(buf) if half == 0 else half
        assert len(data) == expected_len
        assert len(buf) == min(bufsize, step + 1)

    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    # Stored weights are expected to be |w| ** alpha of what was passed in.
    stored = buf.weight[indice]
    expected = np.abs(-data.weight / 2) ** buf._alpha
    assert np.allclose(stored, expected)
 | |
| 
 | |
| 
 | |
def test_update():
    """Check that ReplayBuffer.update copies another buffer's content."""
    source = ReplayBuffer(4, stack_num=2)
    target = ReplayBuffer(4, stack_num=2)
    # Five transitions go into a buffer with four slots.
    for k in range(5):
        source.add(
            obs=np.array([k]),
            act=float(k),
            rew=k * k,
            done=k % 2 == 0,
            info={'incident': 'found'},
        )
    assert len(target) < len(source)
    target.update(source)
    assert len(target) == len(source)
    # The copy is expected to be shifted by one slot relative to the source.
    assert (target[0].obs == source[1].obs).all()
    assert (target[-1].obs == source[0].obs).all()
 | |
| 
 | |
| 
 | |
def test_segtree():
    """Check SegmentTree (sum tree) against a plain NumPy array reference.

    Covers single-element and batched ``__setitem__``, ``reduce`` over
    arbitrary ranges, and ``get_prefix_sum_idx`` for scalar and array
    queries, on a small (8) and a large (16384) tree. When the file is run
    as a script it also prints a timing comparison between
    ``np.random.choice`` and tree-based sampling.
    """
    realop = np.sum
    # small test
    actual_len = 8
    tree = SegmentTree(actual_len)  # 1-15. 8-15 are leaf nodes
    assert len(tree) == actual_len
    assert np.all([tree[i] == 0. for i in range(actual_len)])
    with pytest.raises(IndexError):
        tree[actual_len]
    naive = np.zeros([actual_len])
    for _ in range(1000):
        # random choose a place to perform single update
        index = np.random.randint(actual_len)
        value = np.random.rand()
        naive[index] = value
        tree[index] = value
        # exhaustively compare every non-empty range [i, j) with j < len
        for i in range(actual_len):
            for j in range(i + 1, actual_len):
                ref = realop(naive[i:j])
                out = tree.reduce(i, j)
                assert np.allclose(ref, out), (ref, out)
    assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
    assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
    # batch setitem
    for _ in range(1000):
        index = np.random.choice(actual_len, size=4)
        value = np.random.rand(4)
        naive[index] = value
        tree[index] = value
        assert np.allclose(realop(naive), tree.reduce())
        for _ in range(10):
            left = np.random.randint(actual_len)
            right = np.random.randint(left + 1, actual_len + 1)
            assert np.allclose(realop(naive[left:right]),
                               tree.reduce(left, right))
    # large test
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.zeros([actual_len])
    for _ in range(1000):
        index = np.random.choice(actual_len, size=64)
        value = np.random.rand(64)
        naive[index] = value
        tree[index] = value
        assert np.allclose(realop(naive), tree.reduce())
        for _ in range(10):
            left = np.random.randint(actual_len)
            right = np.random.randint(left + 1, actual_len + 1)
            assert np.allclose(realop(naive[left:right]),
                               tree.reduce(left, right))

    # test prefix-sum-idx
    actual_len = 8
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        # the returned index must bracket ``scalar`` in the cumulative sum
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # corner case here: integer weights, queries exactly on the boundaries.
    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy, so the builtin is used directly.
    naive = np.ones(actual_len, int)
    tree[np.arange(actual_len)] = naive
    for scalar in range(actual_len):
        index = tree.get_prefix_sum_idx(scalar * 1.)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # a zero-weight leaf (index 1) must never be returned
    tree = SegmentTree(10)
    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
    assert np.allclose(tree.get_prefix_sum_idx(
        np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
    # querying the full total is expected to be rejected
    with pytest.raises(AssertionError):
        tree.get_prefix_sum_idx(.2)
    # test large prefix-sum-idx
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()

    # profile (only when executed as a script, never under pytest)
    if __name__ == '__main__':
        size = 100000
        bsz = 64
        naive = np.random.rand(size)
        tree = SegmentTree(size)
        tree[np.arange(size)] = naive

        def sample_npbuf():
            return np.random.choice(size, bsz, p=naive / naive.sum())

        def sample_tree():
            scalar = np.random.rand(bsz) * tree.reduce()
            return tree.get_prefix_sum_idx(scalar)

        # ``setup`` runs the sampler once before timing starts, which
        # serves as a warm-up call.
        print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
        print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
 | |
| 
 | |
| 
 | |
def test_pickle():
    """Check that all three buffer types survive a pickle round trip."""
    size = 100
    vbuf = ReplayBuffer(size, stack_num=2)
    lbuf = ListReplayBuffer()
    pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
    # The reward is a torch tensor, placed on the GPU when one is available.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    rew = torch.tensor([1.]).to(device)

    for k in range(4):
        vbuf.add(obs=Batch(index=np.array([k])), act=0, rew=rew, done=0)
    for k in range(3):
        lbuf.add(obs=Batch(index=np.array([k])), act=1, rew=rew, done=0)
    for k in range(5):
        pbuf.add(obs=Batch(index=np.array([k])),
                 act=2, rew=rew, done=0, weight=np.random.rand())

    # save & load
    originals = (vbuf, lbuf, pbuf)
    restored = [pickle.loads(pickle.dumps(buf)) for buf in originals]
    for loaded, original in zip(restored, originals):
        assert len(loaded) == len(original) and \
            np.allclose(loaded.act, original.act)

    # make sure the meta var is identical
    loaded_v, _, loaded_p = restored
    assert loaded_v.stack_num == vbuf.stack_num
    assert np.allclose(loaded_p.weight[np.arange(len(loaded_p))],
                       pbuf.weight[np.arange(len(pbuf))])
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Run the suite directly (without pytest), in a fixed order.
    for run_test in (
        test_replaybuffer,
        test_ignore_obs_next,
        test_stack,
        test_pickle,
        test_segtree,
        test_priortized_replaybuffer,
    ):
        run_test()
    # Second pass of the prioritized test with a much larger env and buffer.
    test_priortized_replaybuffer(233333, 200000)
    test_update()
 |