| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  | from collections.abc import Callable | 
					
						
							|  |  |  | from typing import Any | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  | import gymnasium as gym | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | import numpy as np | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tianshou.env.worker import EnvWorker | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DummyEnvWorker(EnvWorker): | 
					
						
							|  |  |  |     """Dummy worker used in sequential vector environments.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, env_fn: Callable[[], gym.Env]) -> None: | 
					
						
							|  |  |  |         self.env = env_fn() | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         super().__init__(env_fn) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |     def get_env_attr(self, key: str) -> Any: | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |         return getattr(self.env, key) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |     def set_env_attr(self, key: str, value: Any) -> None: | 
					
						
							| 
									
										
										
										
											2022-07-14 22:52:56 -07:00
										 |  |  |         setattr(self.env.unwrapped, key, value) | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: | 
					
						
							| 
									
										
										
										
											2022-06-27 18:52:21 -04:00
										 |  |  |         if "seed" in kwargs: | 
					
						
							|  |  |  |             super().seed(kwargs["seed"]) | 
					
						
							|  |  |  |         return self.env.reset(**kwargs) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |     def wait(  # type: ignore | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         workers: list["DummyEnvWorker"], | 
					
						
							|  |  |  |         wait_num: int, | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |         timeout: float | None = None, | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     ) -> list["DummyEnvWorker"]: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |         # Sequential EnvWorker objects are always ready | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         return workers | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |     def send(self, action: np.ndarray | None, **kwargs: Any) -> None: | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |         if action is None: | 
					
						
							| 
									
										
										
										
											2022-06-27 18:52:21 -04:00
										 |  |  |             self.result = self.env.reset(**kwargs) | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2022-03-16 14:38:51 +01:00
										 |  |  |             self.result = self.env.step(action)  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |     def seed(self, seed: int | None = None) -> list[int] | None: | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         super().seed(seed) | 
					
						
							| 
									
										
										
										
											2022-06-27 18:52:21 -04:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2022-09-26 18:31:23 +02:00
										 |  |  |             return self.env.seed(seed)  # type: ignore | 
					
						
							|  |  |  |         except (AttributeError, NotImplementedError): | 
					
						
							| 
									
										
										
										
											2022-06-27 18:52:21 -04:00
										 |  |  |             self.env.reset(seed=seed) | 
					
						
							|  |  |  |             return [seed]  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def render(self, **kwargs: Any) -> Any: | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         return self.env.render(**kwargs) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def close_env(self) -> None: | 
					
						
							|  |  |  |         self.env.close() |