| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  | import pickle | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  | import tempfile | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | from timeit import timeit | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import h5py | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | import torch | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from tianshou.data import ( | 
					
						
							|  |  |  |     Batch, | 
					
						
							|  |  |  |     CachedReplayBuffer, | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     HERReplayBuffer, | 
					
						
							|  |  |  |     HERVectorReplayBuffer, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     PrioritizedReplayBuffer, | 
					
						
							|  |  |  |     PrioritizedVectorReplayBuffer, | 
					
						
							|  |  |  |     ReplayBuffer, | 
					
						
							|  |  |  |     SegmentTree, | 
					
						
							|  |  |  |     VectorReplayBuffer, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | from tianshou.data.utils.converter import to_hdf5 | 
					
						
							| 
									
										
										
										
											2020-03-26 09:01:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     from env import MyGoalEnv, MyTestEnv | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  | else:  # pytest | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     from test.base.env import MyGoalEnv, MyTestEnv | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  | def test_replaybuffer(size=10, bufsize=20): | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     buf = ReplayBuffer(bufsize) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     buf.update(buf) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert str(buf) == buf.__class__.__name__ + "()" | 
					
						
							| 
									
										
										
										
											2023-08-09 19:27:18 +02:00
										 |  |  |     obs, _ = env.reset() | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |     action_list = [1] * 5 + [0] * 10 + [1] * 10 | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |     for i, act in enumerate(action_list): | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         obs_next, rew, terminated, truncated, info = env.step(act) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         buf.add( | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |             Batch( | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |                 obs=obs, | 
					
						
							|  |  |  |                 act=[act], | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							|  |  |  |                 terminated=terminated, | 
					
						
							|  |  |  |                 truncated=truncated, | 
					
						
							|  |  |  |                 obs_next=obs_next, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info=info, | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |         obs = obs_next | 
					
						
							| 
									
										
										
										
											2020-06-08 21:53:00 +08:00
										 |  |  |         assert len(buf) == min(bufsize, i + 1) | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     assert buf.act.dtype == int | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     assert buf.act.shape == (bufsize, 1) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     data, indices = buf.sample(bufsize * 2) | 
					
						
							|  |  |  |     assert (indices < len(buf)).all() | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  |     assert (data.obs < size).all() | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert (data.done >= 0).all() | 
					
						
							|  |  |  |     assert (data.done <= 1).all() | 
					
						
							|  |  |  |     assert (data.terminated >= 0).all() | 
					
						
							|  |  |  |     assert (data.terminated <= 1).all() | 
					
						
							|  |  |  |     assert (data.truncated >= 0).all() | 
					
						
							|  |  |  |     assert (data.truncated <= 1).all() | 
					
						
							| 
									
										
										
										
											2020-06-26 12:37:50 +02:00
										 |  |  |     b = ReplayBuffer(size=10) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     # neg bsz should return empty index | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     assert b.sample_indices(-1).tolist() == [] | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     ptr, ep_rew, ep_len, ep_idx = b.add( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         Batch( | 
					
						
							|  |  |  |             obs=1, | 
					
						
							|  |  |  |             act=1, | 
					
						
							|  |  |  |             rew=1, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |             terminated=1, | 
					
						
							|  |  |  |             truncated=0, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             obs_next="str", | 
					
						
							|  |  |  |             info={"a": 3, "b": {"c": 5.0}}, | 
					
						
							|  |  |  |         ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-06-26 12:37:50 +02:00
										 |  |  |     assert b.obs[0] == 1 | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert b.done[0] | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |     assert b.terminated[0] | 
					
						
							|  |  |  |     assert not b.truncated[0] | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert b.obs_next[0] == "str" | 
					
						
							| 
									
										
										
										
											2020-06-26 12:37:50 +02:00
										 |  |  |     assert np.all(b.obs[1:] == 0) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.all(b.obs_next[1:] == np.array(None)) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert b.info.a[0] == 3 | 
					
						
							|  |  |  |     assert b.info.a.dtype == int | 
					
						
							| 
									
										
										
										
											2020-06-26 12:37:50 +02:00
										 |  |  |     assert np.all(b.info.a[1:] == 0) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert b.info.b.c[0] == 5.0 | 
					
						
							|  |  |  |     assert b.info.b.c.dtype == float | 
					
						
							| 
									
										
										
										
											2020-06-27 03:06:40 +02:00
										 |  |  |     assert np.all(b.info.b.c[1:] == 0.0) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert ptr.shape == (1,) | 
					
						
							|  |  |  |     assert ptr[0] == 0 | 
					
						
							|  |  |  |     assert ep_rew.shape == (1,) | 
					
						
							|  |  |  |     assert ep_rew[0] == 1 | 
					
						
							|  |  |  |     assert ep_len.shape == (1,) | 
					
						
							|  |  |  |     assert ep_len[0] == 1 | 
					
						
							|  |  |  |     assert ep_idx.shape == (1,) | 
					
						
							|  |  |  |     assert ep_idx[0] == 0 | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     # test extra keys pop up, the buffer should handle it dynamically | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     batch = Batch( | 
					
						
							|  |  |  |         obs=2, | 
					
						
							|  |  |  |         act=2, | 
					
						
							|  |  |  |         rew=2, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         terminated=0, | 
					
						
							|  |  |  |         truncated=0, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         obs_next="str2", | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         info={"a": 4, "d": {"e": -np.inf}}, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     b.add(batch) | 
					
						
							|  |  |  |     info_keys = ["a", "b", "d"] | 
					
						
							|  |  |  |     assert set(b.info.keys()) == set(info_keys) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert b.info.a[1] == 4 | 
					
						
							|  |  |  |     assert b.info.b.c[1] == 0 | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     assert b.info.d.e[1] == -np.inf | 
					
						
							|  |  |  |     # test batch-style adding method, where len(batch) == 1 | 
					
						
							| 
									
										
										
										
											2021-02-22 19:19:22 +08:00
										 |  |  |     batch.done = [1] | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |     batch.terminated = [0] | 
					
						
							|  |  |  |     batch.truncated = [1] | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     batch.info.e = np.zeros([1, 4]) | 
					
						
							|  |  |  |     batch = Batch.stack([batch]) | 
					
						
							|  |  |  |     ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert ptr.shape == (1,) | 
					
						
							|  |  |  |     assert ptr[0] == 2 | 
					
						
							|  |  |  |     assert ep_rew.shape == (1,) | 
					
						
							|  |  |  |     assert ep_rew[0] == 4 | 
					
						
							|  |  |  |     assert ep_len.shape == (1,) | 
					
						
							|  |  |  |     assert ep_len[0] == 2 | 
					
						
							|  |  |  |     assert ep_idx.shape == (1,) | 
					
						
							|  |  |  |     assert ep_idx[0] == 1 | 
					
						
							|  |  |  |     assert set(b.info.keys()) == {*info_keys, "e"} | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     assert b.info.e.shape == (b.maxsize, 1, 4) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         b[22] | 
					
						
							| 
									
										
										
										
											2021-02-22 19:19:22 +08:00
										 |  |  |     # test prev / next | 
					
						
							|  |  |  |     assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1]) | 
					
						
							|  |  |  |     assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2]) | 
					
						
							|  |  |  |     batch.done = [0] | 
					
						
							|  |  |  |     b.add(batch, buffer_ids=[0]) | 
					
						
							|  |  |  |     assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) | 
					
						
							|  |  |  |     assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  | def test_ignore_obs_next(size=10): | 
					
						
							|  |  |  |     # Issue 82 | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     buf = ReplayBuffer(size, ignore_obs_next=True) | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  |     for i in range(size): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         buf.add( | 
					
						
							|  |  |  |             Batch( | 
					
						
							|  |  |  |                 obs={ | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                     "mask1": np.array([i, 1, 1, 0, 0]), | 
					
						
							|  |  |  |                     "mask2": np.array([i + 4, 0, 1, 0, 0]), | 
					
						
							|  |  |  |                     "mask": i, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                 }, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 act={"act_id": i, "position_id": i + 3}, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                 rew=i, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |                 terminated=i % 3 == 0, | 
					
						
							|  |  |  |                 truncated=False, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info={"if": i}, | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = np.arange(len(buf)) | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  |     orig = np.arange(len(buf)) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     data = buf[indices] | 
					
						
							|  |  |  |     data2 = buf[indices] | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  |     assert isinstance(data, Batch) | 
					
						
							|  |  |  |     assert isinstance(data2, Batch) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     assert np.allclose(indices, orig) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     assert np.allclose(data.obs_next.mask, data2.obs_next.mask) | 
					
						
							|  |  |  |     assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) | 
					
						
							|  |  |  |     buf.stack_num = 4 | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     data = buf[indices] | 
					
						
							|  |  |  |     data2 = buf[indices] | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     assert np.allclose(data.obs_next.mask, data2.obs_next.mask) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							|  |  |  |         data.obs_next.mask, | 
					
						
							|  |  |  |         np.array( | 
					
						
							|  |  |  |             [ | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 [0, 0, 0, 0], | 
					
						
							|  |  |  |                 [1, 1, 1, 2], | 
					
						
							|  |  |  |                 [1, 1, 2, 3], | 
					
						
							|  |  |  |                 [1, 1, 2, 3], | 
					
						
							|  |  |  |                 [4, 4, 4, 5], | 
					
						
							|  |  |  |                 [4, 4, 5, 6], | 
					
						
							|  |  |  |                 [4, 4, 5, 6], | 
					
						
							|  |  |  |                 [7, 7, 7, 8], | 
					
						
							|  |  |  |                 [7, 7, 8, 9], | 
					
						
							|  |  |  |                 [7, 7, 8, 9], | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.allclose(data.info["if"], data2.info["if"]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         data.info["if"], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         np.array( | 
					
						
							|  |  |  |             [ | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 [0, 0, 0, 0], | 
					
						
							|  |  |  |                 [1, 1, 1, 1], | 
					
						
							|  |  |  |                 [1, 1, 1, 2], | 
					
						
							|  |  |  |                 [1, 1, 2, 3], | 
					
						
							|  |  |  |                 [4, 4, 4, 4], | 
					
						
							|  |  |  |                 [4, 4, 4, 5], | 
					
						
							|  |  |  |                 [4, 4, 5, 6], | 
					
						
							|  |  |  |                 [7, 7, 7, 7], | 
					
						
							|  |  |  |                 [7, 7, 7, 8], | 
					
						
							|  |  |  |                 [7, 7, 8, 9], | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     assert data.obs_next | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  | def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     env = MyTestEnv(size) | 
					
						
							| 
									
										
										
										
											2020-06-29 12:18:52 +08:00
										 |  |  |     buf = ReplayBuffer(bufsize, stack_num=stack_num) | 
					
						
							|  |  |  |     buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) | 
					
						
							| 
									
										
										
										
											2020-08-30 05:48:09 +08:00
										 |  |  |     buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |     obs, info = env.reset(options={"state": 1}) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     for _ in range(16): | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         obs_next, rew, terminated, truncated, info = env.step(1) | 
					
						
							|  |  |  |         done = terminated or truncated | 
					
						
							|  |  |  |         buf.add( | 
					
						
							|  |  |  |             Batch( | 
					
						
							|  |  |  |                 obs=obs, | 
					
						
							|  |  |  |                 act=1, | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							|  |  |  |                 terminated=terminated, | 
					
						
							|  |  |  |                 truncated=truncated, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info=info, | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         ) | 
					
						
							|  |  |  |         buf2.add( | 
					
						
							|  |  |  |             Batch( | 
					
						
							|  |  |  |                 obs=obs, | 
					
						
							|  |  |  |                 act=1, | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							|  |  |  |                 terminated=terminated, | 
					
						
							|  |  |  |                 truncated=truncated, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info=info, | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         buf3.add( | 
					
						
							|  |  |  |             Batch( | 
					
						
							|  |  |  |                 obs=[obs, obs, obs], | 
					
						
							|  |  |  |                 act=1, | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |                 terminated=terminated, | 
					
						
							|  |  |  |                 truncated=truncated, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                 obs_next=[obs, obs], | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info=info, | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         obs = obs_next | 
					
						
							|  |  |  |         if done: | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |             obs, info = env.reset(options={"state": 1}) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = np.arange(len(buf)) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf.get(indices, "obs")[..., 0], | 
					
						
							|  |  |  |         [ | 
					
						
							|  |  |  |             [1, 1, 1, 2], | 
					
						
							|  |  |  |             [1, 1, 2, 3], | 
					
						
							|  |  |  |             [1, 2, 3, 4], | 
					
						
							|  |  |  |             [1, 1, 1, 1], | 
					
						
							|  |  |  |             [1, 1, 1, 2], | 
					
						
							|  |  |  |             [1, 1, 2, 3], | 
					
						
							|  |  |  |             [1, 2, 3, 4], | 
					
						
							|  |  |  |             [4, 4, 4, 4], | 
					
						
							|  |  |  |             [1, 1, 1, 1], | 
					
						
							|  |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.allclose(buf.get(indices, "obs"), buf3.get(indices, "obs")) | 
					
						
							|  |  |  |     assert np.allclose(buf.get(indices, "obs"), buf3.get(indices, "obs_next")) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     _, indices = buf2.sample(0) | 
					
						
							|  |  |  |     assert indices.tolist() == [2, 6] | 
					
						
							|  |  |  |     _, indices = buf2.sample(1) | 
					
						
							|  |  |  |     assert indices[0] in [2, 6] | 
					
						
							|  |  |  |     batch, indices = buf2.sample(-1)  # neg bsz -> no data | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert indices.tolist() == [] | 
					
						
							|  |  |  |     assert len(batch) == 0 | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         buf[bufsize * 2] | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  | def test_priortized_replaybuffer(size=32, bufsize=15): | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							|  |  |  |     buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |     obs, info = env.reset() | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  |     action_list = [1] * 5 + [0] * 10 + [1] * 10 | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |     for i, act in enumerate(action_list): | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         obs_next, rew, terminated, truncated, info = env.step(act) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         batch = Batch( | 
					
						
							|  |  |  |             obs=obs, | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |             act=act, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             rew=rew, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |             terminated=terminated, | 
					
						
							|  |  |  |             truncated=truncated, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             obs_next=obs_next, | 
					
						
							|  |  |  |             info=info, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             policy=np.random.randn() - 0.5, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         batch_stack = Batch.stack([batch, batch, batch]) | 
					
						
							|  |  |  |         buf.add(Batch.stack([batch]), buffer_ids=[0]) | 
					
						
							|  |  |  |         buf2.add(batch_stack, buffer_ids=[0, 1, 2]) | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  |         obs = obs_next | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |         data, indices = buf.sample(len(buf) // 2) | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  |         if len(buf) // 2 == 0: | 
					
						
							|  |  |  |             assert len(data) == len(buf) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             assert len(data) == len(buf) // 2 | 
					
						
							| 
									
										
										
										
											2020-06-08 21:53:00 +08:00
										 |  |  |         assert len(buf) == min(bufsize, i + 1) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         assert len(buf2) == min(bufsize, 3 * (i + 1)) | 
					
						
							|  |  |  |     # check single buffer's data | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert buf.info.key.shape == (buf.maxsize,) | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     assert buf.rew.dtype == float | 
					
						
							|  |  |  |     assert buf.done.dtype == bool | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |     assert buf.terminated.dtype == bool | 
					
						
							|  |  |  |     assert buf.truncated.dtype == bool | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     data, indices = buf.sample(len(buf) // 2) | 
					
						
							|  |  |  |     buf.update_weight(indices, -data.weight / 2) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2) ** buf._alpha) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     # check multi buffer's data | 
					
						
							|  |  |  |     assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     batch, indices = buf2.sample(10) | 
					
						
							|  |  |  |     buf2.update_weight(indices, batch.weight * 0) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     weight = buf2[np.arange(buf2.maxsize)].weight | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     mask = np.isin(np.arange(buf2.maxsize), indices) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     assert np.all(weight[mask] == weight[mask][0]) | 
					
						
							|  |  |  |     assert np.all(weight[~mask] == weight[~mask][0]) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert weight[~mask][0] < weight[mask][0] | 
					
						
							|  |  |  |     assert weight[mask][0] <= 1 | 
					
						
							| 
									
										
										
										
											2020-04-26 15:11:20 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  | def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): | 
					
						
							|  |  |  |     env_size = size | 
					
						
							|  |  |  |     env = MyGoalEnv(env_size, array_state=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def compute_reward_fn(ag, g): | 
					
						
							|  |  |  |         return env.compute_reward_fn(ag, g, {}) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     buf2 = HERVectorReplayBuffer( | 
					
						
							|  |  |  |         bufsize, | 
					
						
							|  |  |  |         buffer_num=3, | 
					
						
							|  |  |  |         compute_reward_fn=compute_reward_fn, | 
					
						
							|  |  |  |         horizon=30, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         future_k=8, | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     ) | 
					
						
							|  |  |  |     # Apply her on every episodes sampled (Hacky but necessary for deterministic test) | 
					
						
							|  |  |  |     buf.future_p = 1 | 
					
						
							|  |  |  |     for buf2_buf in buf2.buffers: | 
					
						
							|  |  |  |         buf2_buf.future_p = 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     obs, _ = env.reset() | 
					
						
							|  |  |  |     action_list = [1] * 5 + [0] * 10 + [1] * 10 | 
					
						
							|  |  |  |     for i, act in enumerate(action_list): | 
					
						
							|  |  |  |         obs_next, rew, terminated, truncated, info = env.step(act) | 
					
						
							|  |  |  |         batch = Batch( | 
					
						
							|  |  |  |             obs=obs, | 
					
						
							|  |  |  |             act=[act], | 
					
						
							|  |  |  |             rew=rew, | 
					
						
							|  |  |  |             terminated=terminated, | 
					
						
							|  |  |  |             truncated=truncated, | 
					
						
							|  |  |  |             obs_next=obs_next, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             info=info, | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |         ) | 
					
						
							|  |  |  |         buf.add(batch) | 
					
						
							|  |  |  |         buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2]) | 
					
						
							|  |  |  |         obs = obs_next | 
					
						
							|  |  |  |         assert len(buf) == min(bufsize, i + 1) | 
					
						
							|  |  |  |         assert len(buf2) == min(bufsize, 3 * (i + 1)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     batch, indices = buf.sample(sample_sz) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check that goals are the same for the episode (only 1 ep in buffer) | 
					
						
							|  |  |  |     tmp_indices = indices.copy() | 
					
						
							|  |  |  |     for _ in range(2 * env_size): | 
					
						
							|  |  |  |         obs = buf[tmp_indices].obs | 
					
						
							|  |  |  |         obs_next = buf[tmp_indices].obs_next | 
					
						
							|  |  |  |         rew = buf[tmp_indices].rew | 
					
						
							|  |  |  |         g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         assert np.all(g == g[0]) | 
					
						
							|  |  |  |         assert np.all(g_next == g_next[0]) | 
					
						
							|  |  |  |         assert np.all(rew == (ag_next == g).astype(np.float32)) | 
					
						
							|  |  |  |         tmp_indices = buf.next(tmp_indices) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check that goals are correctly restored | 
					
						
							|  |  |  |     buf._restore_cache() | 
					
						
							|  |  |  |     tmp_indices = indices.copy() | 
					
						
							|  |  |  |     for _ in range(2 * env_size): | 
					
						
							|  |  |  |         obs = buf[tmp_indices].obs | 
					
						
							|  |  |  |         obs_next = buf[tmp_indices].obs_next | 
					
						
							|  |  |  |         g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         assert np.all(g == env_size) | 
					
						
							|  |  |  |         assert np.all(g_next == g_next[0]) | 
					
						
							|  |  |  |         assert np.all(g == g[0]) | 
					
						
							|  |  |  |         tmp_indices = buf.next(tmp_indices) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Test vector buffer | 
					
						
							|  |  |  |     batch, indices = buf2.sample(sample_sz) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check that goals are the same for the episode (only 1 ep in buffer) | 
					
						
							|  |  |  |     tmp_indices = indices.copy() | 
					
						
							|  |  |  |     for _ in range(2 * env_size): | 
					
						
							|  |  |  |         obs = buf2[tmp_indices].obs | 
					
						
							|  |  |  |         obs_next = buf2[tmp_indices].obs_next | 
					
						
							|  |  |  |         rew = buf2[tmp_indices].rew | 
					
						
							|  |  |  |         g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         assert np.all(g == g_next) | 
					
						
							|  |  |  |         assert np.all(rew == (ag_next == g).astype(np.float32)) | 
					
						
							|  |  |  |         tmp_indices = buf2.next(tmp_indices) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check that goals are correctly restored | 
					
						
							|  |  |  |     buf2._restore_cache() | 
					
						
							|  |  |  |     tmp_indices = indices.copy() | 
					
						
							|  |  |  |     for _ in range(2 * env_size): | 
					
						
							|  |  |  |         obs = buf2[tmp_indices].obs | 
					
						
							|  |  |  |         obs_next = buf2[tmp_indices].obs_next | 
					
						
							|  |  |  |         g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] | 
					
						
							|  |  |  |         assert np.all(g == env_size) | 
					
						
							|  |  |  |         assert np.all(g_next == g_next[0]) | 
					
						
							|  |  |  |         assert np.all(g == g[0]) | 
					
						
							|  |  |  |         tmp_indices = buf2.next(tmp_indices) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Test handling cycled indices | 
					
						
							|  |  |  |     env_size = size | 
					
						
							|  |  |  |     bufsize = 15 | 
					
						
							|  |  |  |     env = MyGoalEnv(env_size, array_state=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def compute_reward_fn(ag, g): | 
					
						
							|  |  |  |         return env.compute_reward_fn(ag, g, {}) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     buf._index = 5  # shifted start index | 
					
						
							|  |  |  |     buf.future_p = 1 | 
					
						
							|  |  |  |     action_list = [1] * 10 | 
					
						
							|  |  |  |     for ep_len in [5, 10]: | 
					
						
							|  |  |  |         obs, _ = env.reset() | 
					
						
							|  |  |  |         for i in range(ep_len): | 
					
						
							|  |  |  |             act = 1 | 
					
						
							|  |  |  |             obs_next, rew, terminated, truncated, info = env.step(act) | 
					
						
							|  |  |  |             batch = Batch( | 
					
						
							|  |  |  |                 obs=obs, | 
					
						
							|  |  |  |                 act=[act], | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							|  |  |  |                 terminated=(i == ep_len - 1), | 
					
						
							|  |  |  |                 truncated=(i == ep_len - 1), | 
					
						
							|  |  |  |                 obs_next=obs_next, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info=info, | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |             ) | 
					
						
							|  |  |  |             buf.add(batch) | 
					
						
							|  |  |  |             obs = obs_next | 
					
						
							|  |  |  |     batch, indices = buf.sample(0) | 
					
						
							|  |  |  |     assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal) | 
					
						
							|  |  |  |     assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal) | 
					
						
							|  |  |  |     assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal)  # (same ep) | 
					
						
							|  |  |  |     assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal)  # (diff ep) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-04 08:57:04 +08:00
										 |  |  |     # Another test case for cycled indices | 
					
						
							|  |  |  |     env_size = 99 | 
					
						
							|  |  |  |     bufsize = 15 | 
					
						
							|  |  |  |     env = MyGoalEnv(env_size, array_state=False) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) | 
					
						
							| 
									
										
										
										
											2023-03-04 08:57:04 +08:00
										 |  |  |     buf.future_p = 1 | 
					
						
							|  |  |  |     for x, ep_len in enumerate([10, 20]): | 
					
						
							|  |  |  |         obs, _ = env.reset() | 
					
						
							|  |  |  |         for i in range(ep_len): | 
					
						
							|  |  |  |             act = 1 | 
					
						
							|  |  |  |             obs_next, rew, terminated, truncated, info = env.step(act) | 
					
						
							|  |  |  |             batch = Batch( | 
					
						
							|  |  |  |                 obs=obs, | 
					
						
							|  |  |  |                 act=[act], | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							|  |  |  |                 terminated=(i == ep_len - 1), | 
					
						
							|  |  |  |                 truncated=(i == ep_len - 1), | 
					
						
							|  |  |  |                 obs_next=obs_next, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info=info, | 
					
						
							| 
									
										
										
										
											2023-03-04 08:57:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             if x == 1 and obs["observation"] < 10: | 
					
						
							|  |  |  |                 obs = obs_next | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  |             buf.add(batch) | 
					
						
							|  |  |  |             obs = obs_next | 
					
						
							|  |  |  |     buf._restore_cache() | 
					
						
							|  |  |  |     sample_indices = np.array([10])  # Suppose the sampled indices is [10] | 
					
						
							|  |  |  |     buf.rewrite_transitions(sample_indices) | 
					
						
							|  |  |  |     assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  | def test_update(): | 
					
						
							|  |  |  |     buf1 = ReplayBuffer(4, stack_num=2) | 
					
						
							|  |  |  |     buf2 = ReplayBuffer(4, stack_num=2) | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         buf1.add( | 
					
						
							|  |  |  |             Batch( | 
					
						
							|  |  |  |                 obs=np.array([i]), | 
					
						
							|  |  |  |                 act=float(i), | 
					
						
							|  |  |  |                 rew=i * i, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |                 terminated=i % 2 == 0, | 
					
						
							|  |  |  |                 truncated=False, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info={"incident": "found"}, | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  |     assert len(buf1) > len(buf2) | 
					
						
							|  |  |  |     buf2.update(buf1) | 
					
						
							|  |  |  |     assert len(buf1) == len(buf2) | 
					
						
							|  |  |  |     assert (buf2[0].obs == buf1[1].obs).all() | 
					
						
							|  |  |  |     assert (buf2[-1].obs == buf1[0].obs).all() | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) | 
					
						
							|  |  |  |     with pytest.raises(NotImplementedError): | 
					
						
							|  |  |  |         b.update(b) | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | def test_segtree(): | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     realop = np.sum | 
					
						
							|  |  |  |     # small test | 
					
						
							|  |  |  |     actual_len = 8 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len)  # 1-15. 8-15 are leaf nodes | 
					
						
							|  |  |  |     assert len(tree) == actual_len | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.all([tree[i] == 0.0 for i in range(actual_len)]) | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         tree[actual_len] | 
					
						
							|  |  |  |     naive = np.zeros([actual_len]) | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         # random choose a place to perform single update | 
					
						
							|  |  |  |         index = np.random.randint(actual_len) | 
					
						
							|  |  |  |         value = np.random.rand() | 
					
						
							|  |  |  |         naive[index] = value | 
					
						
							|  |  |  |         tree[index] = value | 
					
						
							|  |  |  |         for i in range(actual_len): | 
					
						
							|  |  |  |             for j in range(i + 1, actual_len): | 
					
						
							|  |  |  |                 ref = realop(naive[i:j]) | 
					
						
							|  |  |  |                 out = tree.reduce(i, j) | 
					
						
							|  |  |  |                 assert np.allclose(ref, out), (ref, out) | 
					
						
							|  |  |  |     assert np.allclose(tree.reduce(start=1), realop(naive[1:])) | 
					
						
							|  |  |  |     assert np.allclose(tree.reduce(end=-1), realop(naive[:-1])) | 
					
						
							|  |  |  |     # batch setitem | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         index = np.random.choice(actual_len, size=4) | 
					
						
							|  |  |  |         value = np.random.rand(4) | 
					
						
							|  |  |  |         naive[index] = value | 
					
						
							|  |  |  |         tree[index] = value | 
					
						
							|  |  |  |         assert np.allclose(realop(naive), tree.reduce()) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         for _ in range(10): | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |             left = np.random.randint(actual_len) | 
					
						
							|  |  |  |             right = np.random.randint(left + 1, actual_len + 1) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     # large test | 
					
						
							|  |  |  |     actual_len = 16384 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len) | 
					
						
							|  |  |  |     naive = np.zeros([actual_len]) | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         index = np.random.choice(actual_len, size=64) | 
					
						
							|  |  |  |         value = np.random.rand(64) | 
					
						
							|  |  |  |         naive[index] = value | 
					
						
							|  |  |  |         tree[index] = value | 
					
						
							|  |  |  |         assert np.allclose(realop(naive), tree.reduce()) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         for _ in range(10): | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |             left = np.random.randint(actual_len) | 
					
						
							|  |  |  |             right = np.random.randint(left + 1, actual_len + 1) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # test prefix-sum-idx | 
					
						
							|  |  |  |     actual_len = 8 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len) | 
					
						
							|  |  |  |     naive = np.random.rand(actual_len) | 
					
						
							|  |  |  |     tree[np.arange(actual_len)] = naive | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         scalar = np.random.rand() * naive.sum() | 
					
						
							|  |  |  |         index = tree.get_prefix_sum_idx(scalar) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     # corner case here | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     naive = np.ones(actual_len, int) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     tree[np.arange(actual_len)] = naive | 
					
						
							|  |  |  |     for scalar in range(actual_len): | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         index = tree.get_prefix_sum_idx(scalar * 1.0) | 
					
						
							|  |  |  |         assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     tree = SegmentTree(10) | 
					
						
							|  |  |  |     tree[np.arange(3)] = np.array([0.1, 0, 0.1]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         tree.get_prefix_sum_idx(np.array([0, 0.1, 0.1 + 1e-6, 0.2 - 1e-6])), | 
					
						
							|  |  |  |         [0, 0, 2, 2], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     with pytest.raises(AssertionError): | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         tree.get_prefix_sum_idx(0.2) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     # test large prefix-sum-idx | 
					
						
							|  |  |  |     actual_len = 16384 | 
					
						
							|  |  |  |     tree = SegmentTree(actual_len) | 
					
						
							|  |  |  |     naive = np.random.rand(actual_len) | 
					
						
							|  |  |  |     tree[np.arange(actual_len)] = naive | 
					
						
							|  |  |  |     for _ in range(1000): | 
					
						
							|  |  |  |         scalar = np.random.rand() * naive.sum() | 
					
						
							|  |  |  |         index = tree.get_prefix_sum_idx(scalar) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # profile | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         size = 100000 | 
					
						
							|  |  |  |         bsz = 64 | 
					
						
							|  |  |  |         naive = np.random.rand(size) | 
					
						
							|  |  |  |         tree = SegmentTree(size) | 
					
						
							|  |  |  |         tree[np.arange(size)] = naive | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def sample_npbuf(): | 
					
						
							|  |  |  |             return np.random.choice(size, bsz, p=naive / naive.sum()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def sample_tree(): | 
					
						
							|  |  |  |             scalar = np.random.rand(bsz) * tree.reduce() | 
					
						
							|  |  |  |             return tree.get_prefix_sum_idx(scalar) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         print("npbuf", timeit(sample_npbuf, setup=sample_npbuf, number=1000)) | 
					
						
							|  |  |  |         print("tree", timeit(sample_tree, setup=sample_tree, number=1000)) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  | def test_pickle(): | 
					
						
							|  |  |  |     size = 100 | 
					
						
							|  |  |  |     vbuf = ReplayBuffer(size, stack_num=2) | 
					
						
							|  |  |  |     pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     rew = np.array([1, 1]) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     for i in range(4): | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         vbuf.add( | 
					
						
							|  |  |  |             Batch( | 
					
						
							|  |  |  |                 obs=Batch(index=np.array([i])), | 
					
						
							|  |  |  |                 act=0, | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							|  |  |  |                 terminated=0, | 
					
						
							|  |  |  |                 truncated=0, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             ), | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         pbuf.add( | 
					
						
							|  |  |  |             Batch( | 
					
						
							|  |  |  |                 obs=Batch(index=np.array([i])), | 
					
						
							|  |  |  |                 act=2, | 
					
						
							|  |  |  |                 rew=rew, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |                 terminated=0, | 
					
						
							|  |  |  |                 truncated=0, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                 info=np.random.rand(), | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     # save & load | 
					
						
							|  |  |  |     _vbuf = pickle.loads(pickle.dumps(vbuf)) | 
					
						
							|  |  |  |     _pbuf = pickle.loads(pickle.dumps(pbuf)) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert len(_vbuf) == len(vbuf) | 
					
						
							|  |  |  |     assert np.allclose(_vbuf.act, vbuf.act) | 
					
						
							|  |  |  |     assert len(_pbuf) == len(pbuf) | 
					
						
							|  |  |  |     assert np.allclose(_pbuf.act, pbuf.act) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  |     # make sure the meta var is identical | 
					
						
							|  |  |  |     assert _vbuf.stack_num == vbuf.stack_num | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))]) | 
					
						
							| 
									
										
										
										
											2020-08-16 16:26:23 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  | def test_hdf5(): | 
					
						
							|  |  |  |     size = 100 | 
					
						
							|  |  |  |     buffers = { | 
					
						
							|  |  |  |         "array": ReplayBuffer(size, stack_num=2), | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |     } | 
					
						
							|  |  |  |     buffer_types = {k: b.__class__ for k, b in buffers.items()} | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     device = "cuda" if torch.cuda.is_available() else "cpu" | 
					
						
							|  |  |  |     info_t = torch.tensor([1.0]).to(device) | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |     for i in range(4): | 
					
						
							|  |  |  |         kwargs = { | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             "obs": Batch(index=np.array([i])), | 
					
						
							|  |  |  |             "act": i, | 
					
						
							|  |  |  |             "rew": np.array([1, 2]), | 
					
						
							|  |  |  |             "terminated": i % 3 == 2, | 
					
						
							|  |  |  |             "truncated": False, | 
					
						
							|  |  |  |             "done": i % 3 == 2, | 
					
						
							|  |  |  |             "info": {"number": {"n": i, "t": info_t}, "extra": None}, | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         buffers["array"].add(Batch(kwargs)) | 
					
						
							|  |  |  |         buffers["prioritized"].add(Batch(kwargs)) | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # save | 
					
						
							|  |  |  |     paths = {} | 
					
						
							|  |  |  |     for k, buf in buffers.items(): | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         f, path = tempfile.mkstemp(suffix=".hdf5") | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |         os.close(f) | 
					
						
							|  |  |  |         buf.save_hdf5(path) | 
					
						
							|  |  |  |         paths[k] = path | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # load replay buffer | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths} | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # compare | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     for k in buffers: | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |         assert len(_buffers[k]) == len(buffers[k]) | 
					
						
							|  |  |  |         assert np.allclose(_buffers[k].act, buffers[k].act) | 
					
						
							|  |  |  |         assert _buffers[k].stack_num == buffers[k].stack_num | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         assert _buffers[k].maxsize == buffers[k].maxsize | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |         assert np.all(_buffers[k]._indices == buffers[k]._indices) | 
					
						
							|  |  |  |     for k in ["array", "prioritized"]: | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         assert _buffers[k]._index == buffers[k]._index | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |         assert isinstance(buffers[k].get(0, "info"), Batch) | 
					
						
							|  |  |  |         assert isinstance(_buffers[k].get(0, "info"), Batch) | 
					
						
							|  |  |  |     for k in ["array"]: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         assert np.all(buffers[k][:].info.number.n == _buffers[k][:].info.number.n) | 
					
						
							|  |  |  |         assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra) | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # raise exception when value cannot be pickled | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     data = {"not_supported": lambda x: x * x} | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |     grp = h5py.Group | 
					
						
							|  |  |  |     with pytest.raises(NotImplementedError): | 
					
						
							|  |  |  |         to_hdf5(data, grp) | 
					
						
							|  |  |  |     # ndarray with data type not supported by HDF5 that cannot be pickled | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     data = {"not_supported": np.array(lambda x: x * x)} | 
					
						
							| 
									
										
										
										
											2020-12-17 01:58:43 +01:00
										 |  |  |     grp = h5py.Group | 
					
						
							|  |  |  |     with pytest.raises(RuntimeError): | 
					
						
							|  |  |  |         to_hdf5(data, grp) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  | def test_replaybuffermanager(): | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     buf = VectorReplayBuffer(20, 4) | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |     batch = Batch( | 
					
						
							|  |  |  |         obs=[1, 2, 3], | 
					
						
							|  |  |  |         act=[1, 2, 3], | 
					
						
							|  |  |  |         rew=[1, 2, 3], | 
					
						
							|  |  |  |         terminated=[0, 0, 1], | 
					
						
							|  |  |  |         truncated=[0, 0, 0], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2]) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.all(ep_len == [0, 0, 1]) | 
					
						
							|  |  |  |     assert np.all(ep_rew == [0, 0, 3]) | 
					
						
							|  |  |  |     assert np.all(ptr == [0, 5, 10]) | 
					
						
							|  |  |  |     assert np.all(ep_idx == [0, 5, 10]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     with pytest.raises(NotImplementedError): | 
					
						
							|  |  |  |         # ReplayBufferManager cannot be updated | 
					
						
							|  |  |  |         buf.update(buf) | 
					
						
							|  |  |  |     # sample index / prev / next / unfinished_index | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf.sample_indices(11000) | 
					
						
							|  |  |  |     assert np.bincount(indices)[[0, 5, 10]].min() >= 3000  # uniform sample | 
					
						
							|  |  |  |     batch, indices = buf.sample(0) | 
					
						
							|  |  |  |     assert np.allclose(indices, [0, 5, 10]) | 
					
						
							|  |  |  |     indices_prev = buf.prev(indices) | 
					
						
							|  |  |  |     assert np.allclose(indices_prev, indices), indices_prev | 
					
						
							|  |  |  |     indices_next = buf.next(indices) | 
					
						
							|  |  |  |     assert np.allclose(indices_next, indices), indices_next | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf.unfinished_index(), [0, 5]) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     buf.add(Batch(obs=[4], act=[4], rew=[4], terminated=[1], truncated=[0]), buffer_ids=[3]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf.unfinished_index(), [0, 5]) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     batch, indices = buf.sample(10) | 
					
						
							|  |  |  |     batch, indices = buf.sample(0) | 
					
						
							|  |  |  |     assert np.allclose(indices, [0, 5, 10, 15]) | 
					
						
							|  |  |  |     indices_prev = buf.prev(indices) | 
					
						
							|  |  |  |     assert np.allclose(indices_prev, indices), indices_prev | 
					
						
							|  |  |  |     indices_next = buf.next(indices) | 
					
						
							|  |  |  |     assert np.allclose(indices_next, indices), indices_next | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     data = np.array([0, 0, 0, 0]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     buf.add( | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buffer_ids=[0, 1, 2, 3], | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |     ) | 
					
						
							|  |  |  |     buf.add( | 
					
						
							|  |  |  |         Batch(obs=data, act=data, rew=data, terminated=1 - data, truncated=data), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buffer_ids=[0, 1, 2, 3], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert len(buf) == 12 | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     buf.add( | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buffer_ids=[0, 1, 2, 3], | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |     ) | 
					
						
							|  |  |  |     buf.add( | 
					
						
							|  |  |  |         Batch(obs=data, act=data, rew=data, terminated=[0, 1, 0, 1], truncated=data), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buffer_ids=[0, 1, 2, 3], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert len(buf) == 20 | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf.sample_indices(120000) | 
					
						
							|  |  |  |     assert np.bincount(indices).min() >= 5000 | 
					
						
							|  |  |  |     batch, indices = buf.sample(10) | 
					
						
							|  |  |  |     indices = buf.sample_indices(0) | 
					
						
							|  |  |  |     assert np.allclose(indices, np.arange(len(buf))) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     # check the actual data stored in buf._meta | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf.done, | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf.prev(indices), | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             3, | 
					
						
							|  |  |  |             3, | 
					
						
							|  |  |  |             5, | 
					
						
							|  |  |  |             5, | 
					
						
							|  |  |  |             6, | 
					
						
							|  |  |  |             8, | 
					
						
							|  |  |  |             8, | 
					
						
							|  |  |  |             10, | 
					
						
							|  |  |  |             11, | 
					
						
							|  |  |  |             11, | 
					
						
							|  |  |  |             13, | 
					
						
							|  |  |  |             13, | 
					
						
							|  |  |  |             15, | 
					
						
							|  |  |  |             16, | 
					
						
							|  |  |  |             16, | 
					
						
							|  |  |  |             18, | 
					
						
							|  |  |  |             18, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf.next(indices), | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             1, | 
					
						
							|  |  |  |             2, | 
					
						
							|  |  |  |             2, | 
					
						
							|  |  |  |             4, | 
					
						
							|  |  |  |             4, | 
					
						
							|  |  |  |             6, | 
					
						
							|  |  |  |             7, | 
					
						
							|  |  |  |             7, | 
					
						
							|  |  |  |             9, | 
					
						
							|  |  |  |             9, | 
					
						
							|  |  |  |             10, | 
					
						
							|  |  |  |             12, | 
					
						
							|  |  |  |             12, | 
					
						
							|  |  |  |             14, | 
					
						
							|  |  |  |             14, | 
					
						
							|  |  |  |             15, | 
					
						
							|  |  |  |             17, | 
					
						
							|  |  |  |             17, | 
					
						
							|  |  |  |             19, | 
					
						
							|  |  |  |             19, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf.unfinished_index(), [4, 14]) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     ptr, ep_rew, ep_len, ep_idx = buf.add( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         Batch(obs=[1], act=[1], rew=[1], terminated=[1], truncated=[0]), | 
					
						
							|  |  |  |         buffer_ids=[2], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.all(ep_len == [3]) | 
					
						
							|  |  |  |     assert np.all(ep_rew == [1]) | 
					
						
							|  |  |  |     assert np.all(ptr == [10]) | 
					
						
							|  |  |  |     assert np.all(ep_idx == [13]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf.unfinished_index(), [4]) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     indices = sorted(buf.sample_indices(0)) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     assert np.allclose(indices, np.arange(len(buf))) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf.prev(indices), | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             3, | 
					
						
							|  |  |  |             3, | 
					
						
							|  |  |  |             5, | 
					
						
							|  |  |  |             5, | 
					
						
							|  |  |  |             6, | 
					
						
							|  |  |  |             8, | 
					
						
							|  |  |  |             8, | 
					
						
							|  |  |  |             14, | 
					
						
							|  |  |  |             11, | 
					
						
							|  |  |  |             11, | 
					
						
							|  |  |  |             13, | 
					
						
							|  |  |  |             13, | 
					
						
							|  |  |  |             15, | 
					
						
							|  |  |  |             16, | 
					
						
							|  |  |  |             16, | 
					
						
							|  |  |  |             18, | 
					
						
							|  |  |  |             18, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf.next(indices), | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             1, | 
					
						
							|  |  |  |             2, | 
					
						
							|  |  |  |             2, | 
					
						
							|  |  |  |             4, | 
					
						
							|  |  |  |             4, | 
					
						
							|  |  |  |             6, | 
					
						
							|  |  |  |             7, | 
					
						
							|  |  |  |             7, | 
					
						
							|  |  |  |             9, | 
					
						
							|  |  |  |             9, | 
					
						
							|  |  |  |             10, | 
					
						
							|  |  |  |             12, | 
					
						
							|  |  |  |             12, | 
					
						
							|  |  |  |             14, | 
					
						
							|  |  |  |             10, | 
					
						
							|  |  |  |             15, | 
					
						
							|  |  |  |             17, | 
					
						
							|  |  |  |             17, | 
					
						
							|  |  |  |             19, | 
					
						
							|  |  |  |             19, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     # corner case: list, int and -1 | 
					
						
							|  |  |  |     assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] | 
					
						
							|  |  |  |     assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] | 
					
						
							|  |  |  |     batch = buf._meta | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     batch.info = np.ones(buf.maxsize) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     buf.set_batch(batch) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     assert np.allclose(buf.buffers[-1].info, [1] * 5) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     assert buf.sample_indices(-1).tolist() == [] | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_cachedbuffer(): | 
					
						
							|  |  |  |     buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     assert buf.sample_indices(0).tolist() == [] | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     # check the normal function/usage/storage in CachedReplayBuffer | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     ptr, ep_rew, ep_len, ep_idx = buf.add( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         Batch(obs=[1], act=[1], rew=[1], terminated=[0], truncated=[0]), | 
					
						
							|  |  |  |         buffer_ids=[1], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     obs = np.zeros(buf.maxsize) | 
					
						
							|  |  |  |     obs[15] = 1 | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf.sample_indices(0) | 
					
						
							|  |  |  |     assert np.allclose(indices, [15]) | 
					
						
							|  |  |  |     assert np.allclose(buf.prev(indices), [15]) | 
					
						
							|  |  |  |     assert np.allclose(buf.next(indices), [15]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf.obs, obs) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.all(ep_len == [0]) | 
					
						
							|  |  |  |     assert np.all(ep_rew == [0.0]) | 
					
						
							|  |  |  |     assert np.all(ptr == [15]) | 
					
						
							|  |  |  |     assert np.all(ep_idx == [15]) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     ptr, ep_rew, ep_len, ep_idx = buf.add( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         Batch(obs=[2], act=[2], rew=[2], terminated=[1], truncated=[0]), | 
					
						
							|  |  |  |         buffer_ids=[3], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     obs[[0, 25]] = 2 | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf.sample_indices(0) | 
					
						
							|  |  |  |     assert np.allclose(indices, [0, 15]) | 
					
						
							|  |  |  |     assert np.allclose(buf.prev(indices), [0, 15]) | 
					
						
							|  |  |  |     assert np.allclose(buf.next(indices), [0, 15]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf.obs, obs) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.all(ep_len == [1]) | 
					
						
							|  |  |  |     assert np.all(ep_rew == [2.0]) | 
					
						
							|  |  |  |     assert np.all(ptr == [0]) | 
					
						
							|  |  |  |     assert np.all(ep_idx == [0]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf.unfinished_index(), [15]) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     assert np.allclose(buf.sample_indices(0), [0, 15]) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     ptr, ep_rew, ep_len, ep_idx = buf.add( | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], terminated=[0, 1], truncated=[0, 0]), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buffer_ids=[3, 1],  # TODO | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.all(ep_len == [0, 2]) | 
					
						
							|  |  |  |     assert np.all(ep_rew == [0, 5.0]) | 
					
						
							|  |  |  |     assert np.all(ptr == [25, 2]) | 
					
						
							|  |  |  |     assert np.all(ep_idx == [25, 1]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] | 
					
						
							|  |  |  |     assert np.allclose(buf.obs, obs) | 
					
						
							|  |  |  |     assert np.allclose(buf.unfinished_index(), [25]) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf.sample_indices(0) | 
					
						
							|  |  |  |     assert np.allclose(indices, [0, 1, 2, 25]) | 
					
						
							|  |  |  |     assert np.allclose(buf.done[indices], [1, 0, 1, 0]) | 
					
						
							|  |  |  |     assert np.allclose(buf.prev(indices), [0, 1, 1, 25]) | 
					
						
							|  |  |  |     assert np.allclose(buf.next(indices), [0, 2, 2, 25]) | 
					
						
							|  |  |  |     indices = buf.sample_indices(10000) | 
					
						
							|  |  |  |     assert np.bincount(indices)[[0, 1, 2, 25]].min() > 2000  # uniform sample | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     # cached buffer with main_buffer size == 0 (no update) | 
					
						
							|  |  |  |     # used in test_collector | 
					
						
							|  |  |  |     buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) | 
					
						
							|  |  |  |     data = np.zeros(4) | 
					
						
							|  |  |  |     rew = np.ones([4, 4]) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 1, 1], truncated=[0, 0, 0, 0])) | 
					
						
							|  |  |  |     buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0])) | 
					
						
							|  |  |  |     buf.add(Batch(obs=data, act=data, rew=rew, terminated=[1, 1, 1, 1], truncated=[0, 0, 0, 0])) | 
					
						
							|  |  |  |     buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0])) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     ptr, ep_rew, ep_len, ep_idx = buf.add( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         Batch(obs=data, act=data, rew=rew, terminated=[0, 1, 0, 1], truncated=[0, 0, 0, 0]), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.all(ptr == [1, -1, 11, -1]) | 
					
						
							|  |  |  |     assert np.all(ep_idx == [0, -1, 10, -1]) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     assert np.all(ep_len == [0, 2, 0, 2]) | 
					
						
							|  |  |  |     assert np.all(ep_rew == [data, data + 2, data, data + 2]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf.done, | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf.sample_indices(0) | 
					
						
							|  |  |  |     assert np.allclose(indices, [0, 1, 10, 11]) | 
					
						
							|  |  |  |     assert np.allclose(buf.prev(indices), [0, 0, 10, 10]) | 
					
						
							|  |  |  |     assert np.allclose(buf.next(indices), [1, 1, 11, 11]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_multibuf_stack(): | 
					
						
							|  |  |  |     size = 5 | 
					
						
							|  |  |  |     bufsize = 9 | 
					
						
							|  |  |  |     stack_num = 4 | 
					
						
							|  |  |  |     cached_num = 3 | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							|  |  |  |     # test if CachedReplayBuffer can handle stack_num + ignore_obs_next | 
					
						
							|  |  |  |     buf4 = CachedReplayBuffer( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), | 
					
						
							|  |  |  |         cached_num, | 
					
						
							|  |  |  |         size, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     # test if CachedReplayBuffer can handle corner case: | 
					
						
							|  |  |  |     # buffer + stack_num + ignore_obs_next + sample_avail | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     buf5 = CachedReplayBuffer( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True), | 
					
						
							|  |  |  |         cached_num, | 
					
						
							|  |  |  |         size, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |     obs, info = env.reset(options={"state": 1}) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     for i in range(18): | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         obs_next, rew, terminated, truncated, info = env.step(1) | 
					
						
							|  |  |  |         done = terminated or truncated | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         obs_list = np.array([obs + size * i for i in range(cached_num)]) | 
					
						
							|  |  |  |         act_list = [1] * cached_num | 
					
						
							|  |  |  |         rew_list = [rew] * cached_num | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         terminated_list = [terminated] * cached_num | 
					
						
							|  |  |  |         truncated_list = [truncated] * cached_num | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         obs_next_list = -obs_list | 
					
						
							|  |  |  |         info_list = [info] * cached_num | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         batch = Batch( | 
					
						
							|  |  |  |             obs=obs_list, | 
					
						
							|  |  |  |             act=act_list, | 
					
						
							|  |  |  |             rew=rew_list, | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |             terminated=terminated_list, | 
					
						
							|  |  |  |             truncated=truncated_list, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             obs_next=obs_next_list, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             info=info_list, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         buf5.add(batch) | 
					
						
							|  |  |  |         buf4.add(batch) | 
					
						
							|  |  |  |         assert np.all(buf4.obs == buf5.obs) | 
					
						
							|  |  |  |         assert np.all(buf4.done == buf5.done) | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         assert np.all(buf4.terminated == buf5.terminated) | 
					
						
							|  |  |  |         assert np.all(buf4.truncated == buf5.truncated) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         obs = obs_next | 
					
						
							|  |  |  |         if done: | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |             obs, info = env.reset(options={"state": 1}) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     # check the `add` order is correct | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							|  |  |  |         buf4.obs.reshape(-1), | 
					
						
							|  |  |  |         [ | 
					
						
							|  |  |  |             12, | 
					
						
							|  |  |  |             13, | 
					
						
							|  |  |  |             14, | 
					
						
							|  |  |  |             4, | 
					
						
							|  |  |  |             6, | 
					
						
							|  |  |  |             7, | 
					
						
							|  |  |  |             8, | 
					
						
							|  |  |  |             9, | 
					
						
							|  |  |  |             11,  # main_buffer | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             2, | 
					
						
							|  |  |  |             3, | 
					
						
							|  |  |  |             4, | 
					
						
							|  |  |  |             0,  # cached_buffer[0] | 
					
						
							|  |  |  |             6, | 
					
						
							|  |  |  |             7, | 
					
						
							|  |  |  |             8, | 
					
						
							|  |  |  |             9, | 
					
						
							|  |  |  |             0,  # cached_buffer[1] | 
					
						
							|  |  |  |             11, | 
					
						
							|  |  |  |             12, | 
					
						
							|  |  |  |             13, | 
					
						
							|  |  |  |             14, | 
					
						
							|  |  |  |             0,  # cached_buffer[2] | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ), buf4.obs | 
					
						
							|  |  |  |     assert np.allclose( | 
					
						
							|  |  |  |         buf4.done, | 
					
						
							|  |  |  |         [ | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0,  # main_buffer | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0,  # cached_buffer[0] | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0,  # cached_buffer[1] | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1, | 
					
						
							|  |  |  |             0,  # cached_buffer[2] | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ), buf4.done | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = sorted(buf4.sample_indices(0)) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert np.allclose(indices, [*list(range(bufsize)), 9, 10, 14, 15, 19, 20]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf4[indices].obs[..., 0], | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             [11, 11, 11, 12], | 
					
						
							|  |  |  |             [11, 11, 12, 13], | 
					
						
							|  |  |  |             [11, 12, 13, 14], | 
					
						
							|  |  |  |             [4, 4, 4, 4], | 
					
						
							|  |  |  |             [6, 6, 6, 6], | 
					
						
							|  |  |  |             [6, 6, 6, 7], | 
					
						
							|  |  |  |             [6, 6, 7, 8], | 
					
						
							|  |  |  |             [6, 7, 8, 9], | 
					
						
							|  |  |  |             [11, 11, 11, 11], | 
					
						
							|  |  |  |             [1, 1, 1, 1], | 
					
						
							|  |  |  |             [1, 1, 1, 2], | 
					
						
							|  |  |  |             [6, 6, 6, 6], | 
					
						
							|  |  |  |             [6, 6, 6, 7], | 
					
						
							|  |  |  |             [11, 11, 11, 11], | 
					
						
							|  |  |  |             [11, 11, 11, 12], | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert np.allclose( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf4[indices].obs_next[..., 0], | 
					
						
							|  |  |  |         [ | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             [11, 11, 12, 13], | 
					
						
							|  |  |  |             [11, 12, 13, 14], | 
					
						
							|  |  |  |             [11, 12, 13, 14], | 
					
						
							|  |  |  |             [4, 4, 4, 4], | 
					
						
							|  |  |  |             [6, 6, 6, 7], | 
					
						
							|  |  |  |             [6, 6, 7, 8], | 
					
						
							|  |  |  |             [6, 7, 8, 9], | 
					
						
							|  |  |  |             [6, 7, 8, 9], | 
					
						
							|  |  |  |             [11, 11, 11, 12], | 
					
						
							|  |  |  |             [1, 1, 1, 2], | 
					
						
							|  |  |  |             [1, 1, 1, 2], | 
					
						
							|  |  |  |             [6, 6, 6, 7], | 
					
						
							|  |  |  |             [6, 6, 6, 7], | 
					
						
							|  |  |  |             [11, 11, 11, 12], | 
					
						
							|  |  |  |             [11, 11, 11, 12], | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf5.sample_indices(0) | 
					
						
							|  |  |  |     assert np.allclose(sorted(indices), [2, 7]) | 
					
						
							|  |  |  |     assert np.all(np.isin(buf5.sample_indices(100), indices)) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     # manually change the stack num | 
					
						
							|  |  |  |     buf5.stack_num = 2 | 
					
						
							|  |  |  |     for buf in buf5.buffers: | 
					
						
							|  |  |  |         buf.stack_num = 2 | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |     indices = buf5.sample_indices(0) | 
					
						
							|  |  |  |     assert np.allclose(sorted(indices), [0, 1, 2, 5, 6, 7, 10, 15, 20]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     batch, _ = buf5.sample(0) | 
					
						
							|  |  |  |     # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next | 
					
						
							|  |  |  |     buf6 = CachedReplayBuffer( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True, ignore_obs_next=True), | 
					
						
							|  |  |  |         cached_num, | 
					
						
							|  |  |  |         size, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     obs = np.random.rand(size, 4, 84, 84) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     buf6.add( | 
					
						
							|  |  |  |         Batch( | 
					
						
							|  |  |  |             obs=[obs[2], obs[0]], | 
					
						
							|  |  |  |             act=[1, 1], | 
					
						
							|  |  |  |             rew=[0, 0], | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |             terminated=[0, 1], | 
					
						
							|  |  |  |             truncated=[0, 0], | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             obs_next=[obs[3], obs[1]], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buffer_ids=[1, 2], | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     assert buf6.obs.shape == (buf6.maxsize, 84, 84) | 
					
						
							|  |  |  |     assert np.allclose(buf6.obs[0], obs[0, -1]) | 
					
						
							|  |  |  |     assert np.allclose(buf6.obs[14], obs[2, -1]) | 
					
						
							|  |  |  |     assert np.allclose(buf6.obs[19], obs[0, -1]) | 
					
						
							|  |  |  |     assert buf6[0].obs.shape == (4, 84, 84) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_multibuf_hdf5(): | 
					
						
							|  |  |  |     size = 100 | 
					
						
							|  |  |  |     buffers = { | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         "vector": VectorReplayBuffer(size * 4, 4), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size), | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     buffer_types = {k: b.__class__ for k, b in buffers.items()} | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     device = "cuda" if torch.cuda.is_available() else "cpu" | 
					
						
							|  |  |  |     info_t = torch.tensor([1.0]).to(device) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     for i in range(4): | 
					
						
							|  |  |  |         kwargs = { | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             "obs": Batch(index=np.array([i])), | 
					
						
							|  |  |  |             "act": i, | 
					
						
							|  |  |  |             "rew": np.array([1, 2]), | 
					
						
							|  |  |  |             "terminated": i % 3 == 2, | 
					
						
							|  |  |  |             "truncated": False, | 
					
						
							|  |  |  |             "done": i % 3 == 2, | 
					
						
							|  |  |  |             "info": {"number": {"n": i, "t": info_t}, "extra": None}, | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buffers["vector"].add(Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]) | 
					
						
							|  |  |  |         buffers["cached"].add(Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # save | 
					
						
							|  |  |  |     paths = {} | 
					
						
							|  |  |  |     for k, buf in buffers.items(): | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         f, path = tempfile.mkstemp(suffix=".hdf5") | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         os.close(f) | 
					
						
							|  |  |  |         buf.save_hdf5(path) | 
					
						
							|  |  |  |         paths[k] = path | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # load replay buffer | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths} | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # compare | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     for k in buffers: | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         assert len(_buffers[k]) == len(buffers[k]) | 
					
						
							|  |  |  |         assert np.allclose(_buffers[k].act, buffers[k].act) | 
					
						
							|  |  |  |         assert _buffers[k].stack_num == buffers[k].stack_num | 
					
						
							|  |  |  |         assert _buffers[k].maxsize == buffers[k].maxsize | 
					
						
							|  |  |  |         assert np.all(_buffers[k]._indices == buffers[k]._indices) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     # check shallow copy in VectorReplayBuffer | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     for k in ["vector", "cached"]: | 
					
						
							|  |  |  |         buffers[k].info.number.n[0] = -100 | 
					
						
							|  |  |  |         assert buffers[k].buffers[0].info.number.n[0] == -100 | 
					
						
							|  |  |  |     # check if still behave normally | 
					
						
							|  |  |  |     for k in ["vector", "cached"]: | 
					
						
							|  |  |  |         kwargs = { | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             "obs": Batch(index=np.array([5])), | 
					
						
							|  |  |  |             "act": 5, | 
					
						
							|  |  |  |             "rew": np.array([2, 1]), | 
					
						
							|  |  |  |             "terminated": False, | 
					
						
							|  |  |  |             "truncated": False, | 
					
						
							|  |  |  |             "done": False, | 
					
						
							|  |  |  |             "info": {"number": {"n": i}, "Timelimit.truncate": True}, | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         buffers[k].add(Batch.stack([kwargs, kwargs, kwargs, kwargs])) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |         act = np.zeros(buffers[k].maxsize) | 
					
						
							|  |  |  |         if k == "vector": | 
					
						
							|  |  |  |             act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) | 
					
						
							|  |  |  |             act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5]) | 
					
						
							|  |  |  |             act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5]) | 
					
						
							|  |  |  |             act[size * 3] = 5 | 
					
						
							|  |  |  |         elif k == "cached": | 
					
						
							|  |  |  |             act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) | 
					
						
							|  |  |  |             act[np.arange(3) + size] = np.array([3, 5, 2]) | 
					
						
							|  |  |  |             act[np.arange(3) + size * 2] = np.array([3, 5, 2]) | 
					
						
							|  |  |  |             act[np.arange(3) + size * 3] = np.array([3, 5, 2]) | 
					
						
							|  |  |  |             act[size * 4] = 5 | 
					
						
							|  |  |  |         assert np.allclose(buffers[k].act, act) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         info_keys = ["number", "extra", "Timelimit.truncate"] | 
					
						
							|  |  |  |         assert set(buffers[k].info.keys()) == set(info_keys) | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     for path in paths.values(): | 
					
						
							|  |  |  |         os.remove(path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-03 13:37:52 -07:00
										 |  |  | def test_from_data(): | 
					
						
							|  |  |  |     obs_data = np.ndarray((10, 3, 3), dtype="uint8") | 
					
						
							|  |  |  |     for i in range(10): | 
					
						
							|  |  |  |         obs_data[i] = i * np.ones((3, 3), dtype="uint8") | 
					
						
							|  |  |  |     obs_next_data = np.zeros_like(obs_data) | 
					
						
							|  |  |  |     obs_next_data[:-1] = obs_data[1:] | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     f, path = tempfile.mkstemp(suffix=".hdf5") | 
					
						
							| 
									
										
										
										
											2022-05-03 13:37:52 -07:00
										 |  |  |     os.close(f) | 
					
						
							|  |  |  |     with h5py.File(path, "w") as f: | 
					
						
							|  |  |  |         obs = f.create_dataset("obs", data=obs_data) | 
					
						
							|  |  |  |         act = f.create_dataset("act", data=np.arange(10, dtype="int32")) | 
					
						
							|  |  |  |         rew = f.create_dataset("rew", data=np.arange(10, dtype="float32")) | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |         terminated = f.create_dataset("terminated", data=np.zeros(10, dtype="bool")) | 
					
						
							|  |  |  |         truncated = f.create_dataset("truncated", data=np.zeros(10, dtype="bool")) | 
					
						
							| 
									
										
										
										
											2022-05-03 13:37:52 -07:00
										 |  |  |         done = f.create_dataset("done", data=np.zeros(10, dtype="bool")) | 
					
						
							|  |  |  |         obs_next = f.create_dataset("obs_next", data=obs_next_data) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         buf = ReplayBuffer.from_data(obs, act, rew, terminated, truncated, done, obs_next) | 
					
						
							| 
									
										
										
										
											2022-05-03 13:37:52 -07:00
										 |  |  |     assert len(buf) == 10 | 
					
						
							|  |  |  |     batch = buf[3] | 
					
						
							|  |  |  |     assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8")) | 
					
						
							|  |  |  |     assert batch.act == 3 | 
					
						
							|  |  |  |     assert batch.rew == 3.0 | 
					
						
							|  |  |  |     assert not batch.done | 
					
						
							|  |  |  |     assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8")) | 
					
						
							|  |  |  |     os.remove(path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-11 01:06:10 +02:00
										 |  |  | def test_custom_key(): | 
					
						
							|  |  |  |     batch = Batch( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         obs_next=np.array( | 
					
						
							|  |  |  |             [ | 
					
						
							| 
									
										
										
										
											2023-08-11 01:06:10 +02:00
										 |  |  |                 [ | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                     1.174, | 
					
						
							|  |  |  |                     -0.1151, | 
					
						
							|  |  |  |                     -0.609, | 
					
						
							|  |  |  |                     -0.5205, | 
					
						
							|  |  |  |                     -0.9316, | 
					
						
							|  |  |  |                     3.236, | 
					
						
							|  |  |  |                     -2.418, | 
					
						
							|  |  |  |                     0.386, | 
					
						
							|  |  |  |                     0.2227, | 
					
						
							|  |  |  |                     -0.5117, | 
					
						
							|  |  |  |                     2.293, | 
					
						
							|  |  |  |                 ], | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ), | 
					
						
							|  |  |  |         rew=np.array([4.28125]), | 
					
						
							|  |  |  |         act=np.array([[-0.3088, -0.4636, 0.4956]]), | 
					
						
							|  |  |  |         truncated=np.array([False]), | 
					
						
							|  |  |  |         obs=np.array( | 
					
						
							|  |  |  |             [ | 
					
						
							| 
									
										
										
										
											2023-08-11 01:06:10 +02:00
										 |  |  |                 [ | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |                     1.193, | 
					
						
							|  |  |  |                     -0.1203, | 
					
						
							|  |  |  |                     -0.6123, | 
					
						
							|  |  |  |                     -0.519, | 
					
						
							|  |  |  |                     -0.9434, | 
					
						
							|  |  |  |                     3.32, | 
					
						
							|  |  |  |                     -2.266, | 
					
						
							|  |  |  |                     0.9116, | 
					
						
							|  |  |  |                     0.623, | 
					
						
							|  |  |  |                     0.1259, | 
					
						
							|  |  |  |                     0.363, | 
					
						
							|  |  |  |                 ], | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ), | 
					
						
							|  |  |  |         terminated=np.array([False]), | 
					
						
							|  |  |  |         done=np.array([False]), | 
					
						
							|  |  |  |         returns=np.array([74.70343082]), | 
					
						
							|  |  |  |         info=Batch(), | 
					
						
							|  |  |  |         policy=Batch(), | 
					
						
							| 
									
										
										
										
											2023-08-11 01:06:10 +02:00
										 |  |  |     ) | 
					
						
							|  |  |  |     buffer_size = len(batch.rew) | 
					
						
							|  |  |  |     buffer = ReplayBuffer(buffer_size) | 
					
						
							|  |  |  |     buffer.add(batch) | 
					
						
							|  |  |  |     sampled_batch, _ = buffer.sample(1) | 
					
						
							|  |  |  |     # Check if they have the same keys | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     assert set(batch.keys()) == set( | 
					
						
							|  |  |  |         sampled_batch.keys(), | 
					
						
							|  |  |  |     ), "Batches have different keys: {} and {}".format(set(batch.keys()), set(sampled_batch.keys())) | 
					
						
							| 
									
										
										
										
											2023-08-11 01:06:10 +02:00
										 |  |  |     # Compare the values for each key | 
					
						
							|  |  |  |     for key in batch.keys(): | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         if isinstance(batch.__dict__[key], np.ndarray) and isinstance( | 
					
						
							|  |  |  |             sampled_batch.__dict__[key], | 
					
						
							|  |  |  |             np.ndarray, | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             assert np.allclose( | 
					
						
							|  |  |  |                 batch.__dict__[key], | 
					
						
							|  |  |  |                 sampled_batch.__dict__[key], | 
					
						
							|  |  |  |             ), f"Value mismatch for key: {key}" | 
					
						
							|  |  |  |         if isinstance(batch.__dict__[key], Batch) and isinstance( | 
					
						
							|  |  |  |             sampled_batch.__dict__[key], | 
					
						
							|  |  |  |             Batch, | 
					
						
							|  |  |  |         ): | 
					
						
							| 
									
										
										
										
											2023-08-11 01:06:10 +02:00
										 |  |  |             assert batch.__dict__[key].is_empty() | 
					
						
							|  |  |  |             assert sampled_batch.__dict__[key].is_empty() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  |     test_replaybuffer() | 
					
						
							| 
									
										
										
										
											2020-06-17 16:37:51 +08:00
										 |  |  |     test_ignore_obs_next() | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     test_stack() | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     test_segtree() | 
					
						
							|  |  |  |     test_priortized_replaybuffer() | 
					
						
							| 
									
										
										
										
											2020-07-20 22:12:57 +08:00
										 |  |  |     test_update() | 
					
						
							| 
									
										
										
										
											2021-01-29 12:23:18 +08:00
										 |  |  |     test_pickle() | 
					
						
							|  |  |  |     test_hdf5() | 
					
						
							|  |  |  |     test_replaybuffermanager() | 
					
						
							|  |  |  |     test_cachedbuffer() | 
					
						
							|  |  |  |     test_multibuf_stack() | 
					
						
							|  |  |  |     test_multibuf_hdf5() | 
					
						
							| 
									
										
										
										
											2022-05-03 13:37:52 -07:00
										 |  |  |     test_from_data() | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     test_herreplaybuffer() | 
					
						
							| 
									
										
										
										
											2023-08-11 01:06:10 +02:00
										 |  |  |     test_custom_key() |