| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  | import torch | 
					
						
							|  |  |  | import pickle | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | from timeit import timeit | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  | from tianshou.data import Batch, SegmentTree, \ | 
					
						
							|  |  |  |     ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer | 
					
						
							| 
									
										
										
										
											2020-03-26 09:01:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |     from env import MyTestEnv | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  | else:  # pytest | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |     from test.base.env import MyTestEnv | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  | def test_replaybuffer(size=10, bufsize=20): | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     buf = ReplayBuffer(bufsize) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     buf.update(buf) | 
					
						
							|  |  |  |     assert str(buf) == buf.__class__.__name__ + '()' | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     obs = env.reset() | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |     action_list = [1] * 5 + [0] * 10 + [1] * 10 | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     for i, a in enumerate(action_list): | 
					
						
							|  |  |  |         obs_next, rew, done, info = env.step(a) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |         buf.add(obs, [a], rew, done, obs_next, info) | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |         obs = obs_next | 
					
						
							| 
									
										
										
										
											2020-06-08 21:53:00 +08:00
										 |  |  |         assert len(buf) == min(bufsize, i + 1) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         buf._add_to_buffer('rew', np.array([1, 2, 3])) | 
					
						
							|  |  |  |     assert buf.act.dtype == np.object | 
					
						
							|  |  |  |     assert isinstance(buf.act[0], list) | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |     data, indice = buf.sample(bufsize * 2) | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  |     assert (indice < len(buf)).all() | 
					
						
							|  |  |  |     assert (data.obs < size).all() | 
					
						
							|  |  |  |     assert (0 <= data.done).all() and (data.done <= 1).all() | 
					
						
							| 
									
										
										
										
											2020-06-26 12:37:50 +02:00
										 |  |  |     b = ReplayBuffer(size=10) | 
					
						
							|  |  |  |     b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) | 
					
						
							|  |  |  |     assert b.obs[0] == 1 | 
					
						
							|  |  |  |     assert b.done[0] == 'str' | 
					
						
							|  |  |  |     assert np.all(b.obs[1:] == 0) | 
					
						
							|  |  |  |     assert np.all(b.done[1:] == np.array(None)) | 
					
						
							|  |  |  |     assert b.info.a[0] == 3 and b.info.a.dtype == np.integer | 
					
						
							|  |  |  |     assert np.all(b.info.a[1:] == 0) | 
					
						
							|  |  |  |     assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact | 
					
						
							| 
									
										
										
										
											2020-06-27 03:06:40 +02:00
										 |  |  |     assert np.all(b.info.b.c[1:] == 0.0) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         b[22] | 
					
						
							|  |  |  |     b = ListReplayBuffer() | 
					
						
							|  |  |  |     with pytest.raises(NotImplementedError): | 
					
						
							|  |  |  |         b.sample(0) | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  | def test_ignore_obs_next(size=10): | 
					
						
							|  |  |  |     # Issue 82 | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     buf = ReplayBuffer(size, ignore_obs_next=True) | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  |     for i in range(size): | 
					
						
							|  |  |  |         buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]), | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |                      'mask2': np.array([i + 4, 0, 1, 0, 0]), | 
					
						
							|  |  |  |                      'mask': i}, | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  |                 act={'act_id': i, | 
					
						
							|  |  |  |                      'position_id': i + 3}, | 
					
						
							|  |  |  |                 rew=i, | 
					
						
							|  |  |  |                 done=i % 3 == 0, | 
					
						
							|  |  |  |                 info={'if': i}) | 
					
						
							|  |  |  |     indice = np.arange(len(buf)) | 
					
						
							|  |  |  |     orig = np.arange(len(buf)) | 
					
						
							|  |  |  |     data = buf[indice] | 
					
						
							|  |  |  |     data2 = buf[indice] | 
					
						
							|  |  |  |     assert isinstance(data, Batch) | 
					
						
							|  |  |  |     assert isinstance(data2, Batch) | 
					
						
							|  |  |  |     assert np.allclose(indice, orig) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     assert np.allclose(data.obs_next.mask, data2.obs_next.mask) | 
					
						
							|  |  |  |     assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) | 
					
						
							|  |  |  |     buf.stack_num = 4 | 
					
						
							|  |  |  |     data = buf[indice] | 
					
						
							|  |  |  |     data2 = buf[indice] | 
					
						
							|  |  |  |     assert np.allclose(data.obs_next.mask, data2.obs_next.mask) | 
					
						
							|  |  |  |     assert np.allclose(data.obs_next.mask, np.array([ | 
					
						
							|  |  |  |         [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], | 
					
						
							|  |  |  |         [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6], | 
					
						
							|  |  |  |         [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])) | 
					
						
							|  |  |  |     assert np.allclose(data.info['if'], data2.info['if']) | 
					
						
							|  |  |  |     assert np.allclose(data.info['if'], np.array([ | 
					
						
							|  |  |  |         [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], | 
					
						
							|  |  |  |         [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6], | 
					
						
							|  |  |  |         [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])) | 
					
						
							|  |  |  |     assert data.obs_next | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  | def test_stack(size=5, bufsize=9, stack_num=4): | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							| 
									
										
										
										
											2020-06-29 12:18:52 +08:00
										 |  |  |     buf = ReplayBuffer(bufsize, stack_num=stack_num) | 
					
						
							|  |  |  |     buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) | 
					
						
							| 
									
										
										
										
											2020-08-30 05:48:09 +08:00
										 |  |  |     buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     obs = env.reset(1) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     for i in range(16): | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         obs_next, rew, done, info = env.step(1) | 
					
						
							|  |  |  |         buf.add(obs, 1, rew, done, None, info) | 
					
						
							| 
									
										
										
										
											2020-06-29 12:18:52 +08:00
										 |  |  |         buf2.add(obs, 1, rew, done, None, info) | 
					
						
							| 
									
										
										
										
											2020-08-30 05:48:09 +08:00
										 |  |  |         buf3.add([None, None, obs], 1, rew, done, [None, obs], info) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         obs = obs_next | 
					
						
							|  |  |  |         if done: | 
					
						
							|  |  |  |             obs = env.reset(1) | 
					
						
							|  |  |  |     indice = np.arange(len(buf)) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert np.allclose(buf.get(indice, 'obs')[..., 0], [ | 
					
						
							|  |  |  |         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], | 
					
						
							|  |  |  |         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], | 
					
						
							|  |  |  |         [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) | 
					
						
							| 
									
										
										
										
											2020-08-30 05:48:09 +08:00
										 |  |  |     assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs')) | 
					
						
							|  |  |  |     assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next')) | 
					
						
							| 
									
										
										
										
											2020-06-29 12:18:52 +08:00
										 |  |  |     _, indice = buf2.sample(0) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     assert indice.tolist() == [2, 6] | 
					
						
							| 
									
										
										
										
											2020-06-29 12:18:52 +08:00
										 |  |  |     _, indice = buf2.sample(1) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     assert indice in [2, 6] | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         buf[bufsize * 2] | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  | def test_priortized_replaybuffer(size=32, bufsize=15): | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							|  |  |  |     buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) | 
					
						
							|  |  |  |     obs = env.reset() | 
					
						
							|  |  |  |     action_list = [1] * 5 + [0] * 10 + [1] * 10 | 
					
						
							|  |  |  |     for i, a in enumerate(action_list): | 
					
						
							|  |  |  |         obs_next, rew, done, info = env.step(a) | 
					
						
							|  |  |  |         buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5) | 
					
						
							|  |  |  |         obs = obs_next | 
					
						
							|  |  |  |         data, indice = buf.sample(len(buf) // 2) | 
					
						
							|  |  |  |         if len(buf) // 2 == 0: | 
					
						
							|  |  |  |             assert len(data) == len(buf) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             assert len(data) == len(buf) // 2 | 
					
						
							| 
									
										
										
										
											2020-06-08 21:53:00 +08:00
										 |  |  |         assert len(buf) == min(bufsize, i + 1) | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  |     data, indice = buf.sample(len(buf) // 2) | 
					
						
							|  |  |  |     buf.update_weight(indice, -data.weight / 2) | 
					
						
							| 
									
										
										
										
											2020-07-10 08:24:11 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							|  |  |  |         buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  | def test_update(): | 
					
						
							|  |  |  |     buf1 = ReplayBuffer(4, stack_num=2) | 
					
						
							|  |  |  |     buf2 = ReplayBuffer(4, stack_num=2) | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         buf1.add(obs=np.array([i]), act=float(i), rew=i * i, | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |                  done=i % 2 == 0, info={'incident': 'found'}) | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  |     assert len(buf1) > len(buf2) | 
					
						
							|  |  |  |     buf2.update(buf1) | 
					
						
							|  |  |  |     assert len(buf1) == len(buf2) | 
					
						
							|  |  |  |     assert (buf2[0].obs == buf1[1].obs).all() | 
					
						
							|  |  |  |     assert (buf2[-1].obs == buf1[0].obs).all() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | def test_segtree(): | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     realop = np.sum | 
					
						
							|  |  |  |     # small test | 
					
						
							|  |  |  |     actual_len = 8 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len)  # 1-15. 8-15 are leaf nodes | 
					
						
							|  |  |  |     assert len(tree) == actual_len | 
					
						
							|  |  |  |     assert np.all([tree[i] == 0. for i in range(actual_len)]) | 
					
						
							|  |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         tree[actual_len] | 
					
						
							|  |  |  |     naive = np.zeros([actual_len]) | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         # random choose a place to perform single update | 
					
						
							|  |  |  |         index = np.random.randint(actual_len) | 
					
						
							|  |  |  |         value = np.random.rand() | 
					
						
							|  |  |  |         naive[index] = value | 
					
						
							|  |  |  |         tree[index] = value | 
					
						
							|  |  |  |         for i in range(actual_len): | 
					
						
							|  |  |  |             for j in range(i + 1, actual_len): | 
					
						
							|  |  |  |                 ref = realop(naive[i:j]) | 
					
						
							|  |  |  |                 out = tree.reduce(i, j) | 
					
						
							|  |  |  |                 assert np.allclose(ref, out), (ref, out) | 
					
						
							|  |  |  |     assert np.allclose(tree.reduce(start=1), realop(naive[1:])) | 
					
						
							|  |  |  |     assert np.allclose(tree.reduce(end=-1), realop(naive[:-1])) | 
					
						
							|  |  |  |     # batch setitem | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         index = np.random.choice(actual_len, size=4) | 
					
						
							|  |  |  |         value = np.random.rand(4) | 
					
						
							|  |  |  |         naive[index] = value | 
					
						
							|  |  |  |         tree[index] = value | 
					
						
							|  |  |  |         assert np.allclose(realop(naive), tree.reduce()) | 
					
						
							|  |  |  |         for i in range(10): | 
					
						
							|  |  |  |             left = np.random.randint(actual_len) | 
					
						
							|  |  |  |             right = np.random.randint(left + 1, actual_len + 1) | 
					
						
							|  |  |  |             assert np.allclose(realop(naive[left:right]), | 
					
						
							|  |  |  |                                tree.reduce(left, right)) | 
					
						
							|  |  |  |     # large test | 
					
						
							|  |  |  |     actual_len = 16384 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len) | 
					
						
							|  |  |  |     naive = np.zeros([actual_len]) | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         index = np.random.choice(actual_len, size=64) | 
					
						
							|  |  |  |         value = np.random.rand(64) | 
					
						
							|  |  |  |         naive[index] = value | 
					
						
							|  |  |  |         tree[index] = value | 
					
						
							|  |  |  |         assert np.allclose(realop(naive), tree.reduce()) | 
					
						
							|  |  |  |         for i in range(10): | 
					
						
							|  |  |  |             left = np.random.randint(actual_len) | 
					
						
							|  |  |  |             right = np.random.randint(left + 1, actual_len + 1) | 
					
						
							|  |  |  |             assert np.allclose(realop(naive[left:right]), | 
					
						
							|  |  |  |                                tree.reduce(left, right)) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # test prefix-sum-idx | 
					
						
							|  |  |  |     actual_len = 8 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len) | 
					
						
							|  |  |  |     naive = np.random.rand(actual_len) | 
					
						
							|  |  |  |     tree[np.arange(actual_len)] = naive | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         scalar = np.random.rand() * naive.sum() | 
					
						
							|  |  |  |         index = tree.get_prefix_sum_idx(scalar) | 
					
						
							|  |  |  |         assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() | 
					
						
							|  |  |  |     # corner case here | 
					
						
							|  |  |  |     naive = np.ones(actual_len, np.int) | 
					
						
							|  |  |  |     tree[np.arange(actual_len)] = naive | 
					
						
							|  |  |  |     for scalar in range(actual_len): | 
					
						
							|  |  |  |         index = tree.get_prefix_sum_idx(scalar * 1.) | 
					
						
							|  |  |  |         assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() | 
					
						
							|  |  |  |     tree = SegmentTree(10) | 
					
						
							|  |  |  |     tree[np.arange(3)] = np.array([0.1, 0, 0.1]) | 
					
						
							|  |  |  |     assert np.allclose(tree.get_prefix_sum_idx( | 
					
						
							|  |  |  |         np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2]) | 
					
						
							|  |  |  |     with pytest.raises(AssertionError): | 
					
						
							|  |  |  |         tree.get_prefix_sum_idx(.2) | 
					
						
							|  |  |  |     # test large prefix-sum-idx | 
					
						
							|  |  |  |     actual_len = 16384 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len) | 
					
						
							|  |  |  |     naive = np.random.rand(actual_len) | 
					
						
							|  |  |  |     tree[np.arange(actual_len)] = naive | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         scalar = np.random.rand() * naive.sum() | 
					
						
							|  |  |  |         index = tree.get_prefix_sum_idx(scalar) | 
					
						
							|  |  |  |         assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # profile | 
					
						
							|  |  |  |     if __name__ == '__main__': | 
					
						
							|  |  |  |         size = 100000 | 
					
						
							|  |  |  |         bsz = 64 | 
					
						
							|  |  |  |         naive = np.random.rand(size) | 
					
						
							|  |  |  |         tree = SegmentTree(size) | 
					
						
							|  |  |  |         tree[np.arange(size)] = naive | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def sample_npbuf(): | 
					
						
							|  |  |  |             return np.random.choice(size, bsz, p=naive / naive.sum()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def sample_tree(): | 
					
						
							|  |  |  |             scalar = np.random.rand(bsz) * tree.reduce() | 
					
						
							|  |  |  |             return tree.get_prefix_sum_idx(scalar) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000)) | 
					
						
							|  |  |  |         print('tree', timeit(sample_tree, setup=sample_tree, number=1000)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  | def test_pickle(): | 
					
						
							|  |  |  |     size = 100 | 
					
						
							|  |  |  |     vbuf = ReplayBuffer(size, stack_num=2) | 
					
						
							|  |  |  |     lbuf = ListReplayBuffer() | 
					
						
							|  |  |  |     pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) | 
					
						
							|  |  |  |     device = 'cuda' if torch.cuda.is_available() else 'cpu' | 
					
						
							|  |  |  |     rew = torch.tensor([1.]).to(device) | 
					
						
							|  |  |  |     for i in range(4): | 
					
						
							|  |  |  |         vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) | 
					
						
							|  |  |  |     for i in range(3): | 
					
						
							|  |  |  |         lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0) | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         pbuf.add(obs=Batch(index=np.array([i])), | 
					
						
							|  |  |  |                  act=2, rew=rew, done=0, weight=np.random.rand()) | 
					
						
							|  |  |  |     # save & load | 
					
						
							|  |  |  |     _vbuf = pickle.loads(pickle.dumps(vbuf)) | 
					
						
							|  |  |  |     _lbuf = pickle.loads(pickle.dumps(lbuf)) | 
					
						
							|  |  |  |     _pbuf = pickle.loads(pickle.dumps(pbuf)) | 
					
						
							|  |  |  |     assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act) | 
					
						
							|  |  |  |     assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act) | 
					
						
							|  |  |  |     assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) | 
					
						
							|  |  |  |     # make sure the meta var is identical | 
					
						
							|  |  |  |     assert _vbuf.stack_num == vbuf.stack_num | 
					
						
							|  |  |  |     assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], | 
					
						
							|  |  |  |                        pbuf.weight[np.arange(len(pbuf))]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  |     test_replaybuffer() | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  |     test_ignore_obs_next() | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     test_stack() | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     test_pickle() | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     test_segtree() | 
					
						
							|  |  |  |     test_priortized_replaybuffer() | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  |     test_priortized_replaybuffer(233333, 200000) | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  |     test_update() |