| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  | from typing import Any, Callable, List, Optional, Tuple, Union | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | import gym | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tianshou.env.worker import EnvWorker | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     import ray | 
					
						
							|  |  |  | except ImportError: | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  | class _SetAttrWrapper(gym.Wrapper): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def set_env_attr(self, key: str, value: Any) -> None: | 
					
						
							|  |  |  |         setattr(self.env, key, value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_env_attr(self, key: str) -> Any: | 
					
						
							|  |  |  |         return getattr(self.env, key) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | class RayEnvWorker(EnvWorker): | 
					
						
							|  |  |  |     """Ray worker used in RayVectorEnv.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, env_fn: Callable[[], gym.Env]) -> None: | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |         self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn()) | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         super().__init__(env_fn) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |     def get_env_attr(self, key: str) -> Any: | 
					
						
							|  |  |  |         return ray.get(self.env.get_env_attr.remote(key)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def set_env_attr(self, key: str, value: Any) -> None: | 
					
						
							|  |  |  |         ray.get(self.env.set_env_attr.remote(key, value)) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def reset(self) -> Any: | 
					
						
							|  |  |  |         return ray.get(self.env.reset.remote()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |     def wait(  # type: ignore | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |         workers: List["RayEnvWorker"], wait_num: int, timeout: Optional[float] = None | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     ) -> List["RayEnvWorker"]: | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         results = [x.result for x in workers] | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         return [workers[results.index(result)] for result in ready_results] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |     def send(self, action: Optional[np.ndarray]) -> None: | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         # self.action is actually a handle | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |         if action is None: | 
					
						
							|  |  |  |             self.result = self.env.reset.remote() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.result = self.env.step.remote(action) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def recv( | 
					
						
							|  |  |  |         self | 
					
						
							|  |  |  |     ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         return ray.get(self.result) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |     def seed(self, seed: Optional[int] = None) -> List[int]: | 
					
						
							|  |  |  |         super().seed(seed) | 
					
						
							|  |  |  |         return ray.get(self.env.seed.remote(seed)) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def render(self, **kwargs: Any) -> Any: | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         return ray.get(self.env.render.remote(**kwargs)) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def close_env(self) -> None: | 
					
						
							|  |  |  |         ray.get(self.env.close.remote()) |