| 
									
										
										
										
											2020-03-11 10:56:38 +08:00
										 |  |  | import time | 
					
						
							|  |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-04-04 21:02:06 +08:00
										 |  |  | from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv | 
					
						
							| 
									
										
										
										
											2020-03-11 10:56:38 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |     from env import MyTestEnv | 
					
						
							|  |  |  | else:  # pytest | 
					
						
							|  |  |  |     from test.base.env import MyTestEnv | 
					
						
							| 
									
										
										
										
											2020-03-11 10:56:38 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  | def test_vecenv(size=10, num=8, sleep=0.001): | 
					
						
							|  |  |  |     verbose = __name__ == '__main__' | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |     env_fns = [ | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size, sleep=sleep), | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size + 1, sleep=sleep), | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size + 2, sleep=sleep), | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size + 3, sleep=sleep), | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size + 4, sleep=sleep), | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size + 5, sleep=sleep), | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size + 6, sleep=sleep), | 
					
						
							|  |  |  |         lambda: MyTestEnv(size=size + 7, sleep=sleep), | 
					
						
							|  |  |  |     ] | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  |     venv = [ | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |         VectorEnv(env_fns), | 
					
						
							|  |  |  |         SubprocVectorEnv(env_fns), | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  |     ] | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |     if verbose: | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |         venv.append(RayVectorEnv(env_fns)) | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  |     for v in venv: | 
					
						
							|  |  |  |         v.seed() | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |     action_list = [1] * 5 + [0] * 10 + [1] * 20 | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  |     if not verbose: | 
					
						
							|  |  |  |         o = [v.reset() for v in venv] | 
					
						
							|  |  |  |         for i, a in enumerate(action_list): | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |             o = [] | 
					
						
							|  |  |  |             for v in venv: | 
					
						
							|  |  |  |                 A, B, C, D = v.step([a] * num) | 
					
						
							|  |  |  |                 if sum(C): | 
					
						
							|  |  |  |                     A = v.reset(np.where(C)[0]) | 
					
						
							|  |  |  |                 o.append([A, B, C, D]) | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  |             for i in zip(*o): | 
					
						
							|  |  |  |                 for j in range(1, len(i)): | 
					
						
							|  |  |  |                     assert (i[0] == i[j]).all() | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         t = [0, 0, 0] | 
					
						
							|  |  |  |         for i, e in enumerate(venv): | 
					
						
							|  |  |  |             t[i] = time.time() | 
					
						
							|  |  |  |             e.reset() | 
					
						
							|  |  |  |             for a in action_list: | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |                 done = e.step([a] * num)[2] | 
					
						
							|  |  |  |                 if sum(done) > 0: | 
					
						
							|  |  |  |                     e.reset(np.where(done)[0]) | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  |             t[i] = time.time() - t[i] | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  |         print(f'VectorEnv: {t[0]:.6f}s') | 
					
						
							|  |  |  |         print(f'SubprocVectorEnv: {t[1]:.6f}s') | 
					
						
							|  |  |  |         print(f'RayVectorEnv: {t[2]:.6f}s') | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |     for v in venv: | 
					
						
							|  |  |  |         v.close() | 
					
						
							| 
									
										
										
										
											2020-03-11 16:14:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 10:56:38 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |     test_vecenv() |