| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from abc import ABC, abstractmethod | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  | from typing import Any, Callable, List, Optional, Tuple, Union | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | import gym | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-03-21 16:29:27 -04:00
										 |  |  | from tianshou.utils import deprecation | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class EnvWorker(ABC): | 
					
						
							|  |  |  |     """An abstract worker for an environment.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, env_fn: Callable[[], gym.Env]) -> None: | 
					
						
							|  |  |  |         self._env_fn = env_fn | 
					
						
							|  |  |  |         self.is_closed = False | 
					
						
							| 
									
										
										
										
											2022-02-25 11:05:02 -05:00
										 |  |  |         self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], | 
					
						
							|  |  |  |                            np.ndarray] | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |         self.action_space = self.get_env_attr("action_space")  # noqa: B009 | 
					
						
							| 
									
										
										
										
											2022-02-25 11:05:02 -05:00
										 |  |  |         self.is_reset = False | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |     def get_env_attr(self, key: str) -> Any: | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def set_env_attr(self, key: str, value: Any) -> None: | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |     def send(self, action: Optional[np.ndarray]) -> None: | 
					
						
							|  |  |  |         """Send action signal to low-level worker.
 | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |         When action is None, it indicates sending "reset" signal; otherwise | 
					
						
							|  |  |  |         it indicates "step" signal. The paired return value from "recv" | 
					
						
							|  |  |  |         function is determined by such kind of different signal. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2022-02-25 11:05:02 -05:00
										 |  |  |         if hasattr(self, "send_action"): | 
					
						
							| 
									
										
										
										
											2022-03-21 16:29:27 -04:00
										 |  |  |             deprecation( | 
					
						
							| 
									
										
										
										
											2022-02-25 11:05:02 -05:00
										 |  |  |                 "send_action will soon be deprecated. " | 
					
						
							|  |  |  |                 "Please use send and recv for your own EnvWorker." | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             if action is None: | 
					
						
							|  |  |  |                 self.is_reset = True | 
					
						
							|  |  |  |                 self.result = self.reset() | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 self.is_reset = False | 
					
						
							|  |  |  |                 self.send_action(action)  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |     def recv( | 
					
						
							|  |  |  |         self | 
					
						
							|  |  |  |     ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: | 
					
						
							|  |  |  |         """Receive result from low-level worker.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         If the last "send" function sends a NULL action, it only returns a | 
					
						
							|  |  |  |         single observation; otherwise it returns a tuple of (obs, rew, done, | 
					
						
							|  |  |  |         info). | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2022-02-25 11:05:02 -05:00
										 |  |  |         if hasattr(self, "get_result"): | 
					
						
							| 
									
										
										
										
											2022-03-21 16:29:27 -04:00
										 |  |  |             deprecation( | 
					
						
							| 
									
										
										
										
											2022-02-25 11:05:02 -05:00
										 |  |  |                 "get_result will soon be deprecated. " | 
					
						
							|  |  |  |                 "Please use send and recv for your own EnvWorker." | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             if not self.is_reset: | 
					
						
							|  |  |  |                 self.result = self.get_result()  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         return self.result | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |     def reset(self) -> np.ndarray: | 
					
						
							|  |  |  |         self.send(None) | 
					
						
							|  |  |  |         return self.recv()  # type: ignore | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def step( | 
					
						
							|  |  |  |         self, action: np.ndarray | 
					
						
							|  |  |  |     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |         """Perform one timestep of the environment's dynamic.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |         "send" and "recv" are coupled in sync simulation, so users only call | 
					
						
							|  |  |  |         "step" function. But they can be called separately in async | 
					
						
							|  |  |  |         simulation, i.e. someone calls "send" first, and calls "recv" later. | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |         self.send(action) | 
					
						
							|  |  |  |         return self.recv()  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def wait( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         workers: List["EnvWorker"], | 
					
						
							|  |  |  |         wait_num: int, | 
					
						
							|  |  |  |         timeout: Optional[float] = None | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     ) -> List["EnvWorker"]: | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         """Given a list of workers, return those ready ones.""" | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         return self.action_space.seed(seed)  # issue 299 | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def render(self, **kwargs: Any) -> Any: | 
					
						
							|  |  |  |         """Render the environment.""" | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def close_env(self) -> None: | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def close(self) -> None: | 
					
						
							|  |  |  |         if self.is_closed: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         self.is_closed = True | 
					
						
							|  |  |  |         self.close_env() |