| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  | import sys | 
					
						
							|  |  |  | import gym | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | import tqdm | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  | from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_replaybuffer(task="Pendulum-v0"): | 
					
						
							|  |  |  |     total_count = 5 | 
					
						
							|  |  |  |     for _ in tqdm.trange(total_count, desc="ReplayBuffer"): | 
					
						
							|  |  |  |         env = gym.make(task) | 
					
						
							|  |  |  |         buf = ReplayBuffer(10000) | 
					
						
							|  |  |  |         obs = env.reset() | 
					
						
							|  |  |  |         for i in range(100000): | 
					
						
							|  |  |  |             act = env.action_space.sample() | 
					
						
							|  |  |  |             obs_next, rew, done, info = env.step(act) | 
					
						
							|  |  |  |             batch = Batch( | 
					
						
							|  |  |  |                 obs=np.array([obs]), | 
					
						
							|  |  |  |                 act=np.array([act]), | 
					
						
							|  |  |  |                 rew=np.array([rew]), | 
					
						
							|  |  |  |                 done=np.array([done]), | 
					
						
							|  |  |  |                 obs_next=np.array([obs_next]), | 
					
						
							|  |  |  |                 info=np.array([info]), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             buf.add(batch, buffer_ids=[0]) | 
					
						
							|  |  |  |             obs = obs_next | 
					
						
							|  |  |  |             if done: | 
					
						
							|  |  |  |                 obs = env.reset() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_vectorbuffer(task="Pendulum-v0"): | 
					
						
							|  |  |  |     total_count = 5 | 
					
						
							|  |  |  |     for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"): | 
					
						
							|  |  |  |         env = gym.make(task) | 
					
						
							|  |  |  |         buf = VectorReplayBuffer(total_size=10000, buffer_num=1) | 
					
						
							|  |  |  |         obs = env.reset() | 
					
						
							|  |  |  |         for i in range(100000): | 
					
						
							|  |  |  |             act = env.action_space.sample() | 
					
						
							|  |  |  |             obs_next, rew, done, info = env.step(act) | 
					
						
							|  |  |  |             batch = Batch( | 
					
						
							|  |  |  |                 obs=np.array([obs]), | 
					
						
							|  |  |  |                 act=np.array([act]), | 
					
						
							|  |  |  |                 rew=np.array([rew]), | 
					
						
							|  |  |  |                 done=np.array([done]), | 
					
						
							|  |  |  |                 obs_next=np.array([obs_next]), | 
					
						
							|  |  |  |                 info=np.array([info]), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             buf.add(batch) | 
					
						
							|  |  |  |             obs = obs_next | 
					
						
							|  |  |  |             if done: | 
					
						
							|  |  |  |                 obs = env.reset() | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-02 18:24:40 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     t0 = time.time() | 
					
						
							|  |  |  |     test_replaybuffer(sys.argv[-1]) | 
					
						
							|  |  |  |     print("test replaybuffer: ", time.time() - t0) | 
					
						
							|  |  |  |     t0 = time.time() | 
					
						
							|  |  |  |     test_vectorbuffer(sys.argv[-1]) | 
					
						
							|  |  |  |     print("test vectorbuffer: ", time.time() - t0) |