| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | from tianshou.data import ReplayBuffer | 
					
						
							| 
									
										
										
										
											2020-03-26 09:01:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |     from env import MyTestEnv | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  | else:  # pytest | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |     from test.base.env import MyTestEnv | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  | def test_replaybuffer(size=10, bufsize=20): | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     buf = ReplayBuffer(bufsize) | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |     buf2 = ReplayBuffer(bufsize) | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     obs = env.reset() | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |     action_list = [1] * 5 + [0] * 10 + [1] * 10 | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     for i, a in enumerate(action_list): | 
					
						
							|  |  |  |         obs_next, rew, done, info = env.step(a) | 
					
						
							|  |  |  |         buf.add(obs, a, rew, done, obs_next, info) | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |         obs = obs_next | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |         assert len(buf) == min(bufsize, i + 1), print(len(buf), i) | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |     data, indice = buf.sample(bufsize * 2) | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  |     assert (indice < len(buf)).all() | 
					
						
							|  |  |  |     assert (data.obs < size).all() | 
					
						
							|  |  |  |     assert (0 <= data.done).all() and (data.done <= 1).all() | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |     assert len(buf) > len(buf2) | 
					
						
							|  |  |  |     buf2.update(buf) | 
					
						
							|  |  |  |     assert len(buf) == len(buf2) | 
					
						
							|  |  |  |     assert buf2[0].obs == buf[5].obs | 
					
						
							|  |  |  |     assert buf2[-1].obs == buf[4].obs | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  | def test_stack(size=5, bufsize=9, stack_num=4): | 
					
						
							|  |  |  |     env = MyTestEnv(size) | 
					
						
							|  |  |  |     buf = ReplayBuffer(bufsize, stack_num) | 
					
						
							|  |  |  |     obs = env.reset(1) | 
					
						
							|  |  |  |     for i in range(15): | 
					
						
							|  |  |  |         obs_next, rew, done, info = env.step(1) | 
					
						
							|  |  |  |         buf.add(obs, 1, rew, done, None, info) | 
					
						
							|  |  |  |         obs = obs_next | 
					
						
							|  |  |  |         if done: | 
					
						
							|  |  |  |             obs = env.reset(1) | 
					
						
							|  |  |  |     indice = np.arange(len(buf)) | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |     assert abs(buf.get(indice, 'obs') - np.array([ | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], | 
					
						
							|  |  |  |         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], | 
					
						
							|  |  |  |         [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6 | 
					
						
							|  |  |  |     print(buf) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2020-03-11 18:02:19 +08:00
										 |  |  |     test_replaybuffer() | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     test_stack() |