| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | import ctypes | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | from collections import OrderedDict | 
					
						
							|  |  |  | from multiprocessing import Array, Pipe, connection | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from multiprocessing.context import Process | 
					
						
							|  |  |  | from typing import Any, Callable, List, Optional, Tuple, Union | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import gym | 
					
						
							|  |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from tianshou.env.utils import CloudpickleWrapper | 
					
						
							|  |  |  | from tianshou.env.worker import EnvWorker | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
# Lookup table from NumPy scalar types to the corresponding ctypes element
# types. ShArray indexes it with ``dtype.type`` to choose the element type of
# its multiprocessing Array, so a dtype missing from this table raises
# KeyError there.
_NP_TO_CT = {
    np.bool_: ctypes.c_bool,
    np.uint8: ctypes.c_uint8,
    np.uint16: ctypes.c_uint16,
    np.uint32: ctypes.c_uint32,
    np.uint64: ctypes.c_uint64,
    np.int8: ctypes.c_int8,
    np.int16: ctypes.c_int16,
    np.int32: ctypes.c_int32,
    np.int64: ctypes.c_int64,
    np.float32: ctypes.c_float,
    np.float64: ctypes.c_double,
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ShArray: | 
					
						
							|  |  |  |     """Wrapper of multiprocessing Array.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None: | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |         self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |         self.dtype = dtype | 
					
						
							|  |  |  |         self.shape = shape | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def save(self, ndarray: np.ndarray) -> None: | 
					
						
							|  |  |  |         assert isinstance(ndarray, np.ndarray) | 
					
						
							|  |  |  |         dst = self.arr.get_obj() | 
					
						
							|  |  |  |         dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape) | 
					
						
							|  |  |  |         np.copyto(dst_np, ndarray) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get(self) -> np.ndarray: | 
					
						
							|  |  |  |         obj = self.arr.get_obj() | 
					
						
							|  |  |  |         return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]: | 
					
						
							|  |  |  |     if isinstance(space, gym.spaces.Dict): | 
					
						
							|  |  |  |         assert isinstance(space.spaces, OrderedDict) | 
					
						
							|  |  |  |         return {k: _setup_buf(v) for k, v in space.spaces.items()} | 
					
						
							|  |  |  |     elif isinstance(space, gym.spaces.Tuple): | 
					
						
							|  |  |  |         assert isinstance(space.spaces, tuple) | 
					
						
							|  |  |  |         return tuple([_setup_buf(t) for t in space.spaces]) | 
					
						
							|  |  |  |     else: | 
					
						
							| 
									
										
										
										
											2022-03-16 14:38:51 +01:00
										 |  |  |         return ShArray(space.dtype, space.shape)  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  | def _worker( | 
					
						
							|  |  |  |     parent: connection.Connection, | 
					
						
							|  |  |  |     p: connection.Connection, | 
					
						
							|  |  |  |     env_fn_wrapper: CloudpickleWrapper, | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |     obs_bufs: Optional[Union[dict, tuple, ShArray]] = None, | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  | ) -> None: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def _encode_obs( | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |         obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray] | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     ) -> None: | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |         if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray): | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |             buffer.save(obs) | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |         elif isinstance(obs, tuple) and isinstance(buffer, tuple): | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |             for o, b in zip(obs, buffer): | 
					
						
							|  |  |  |                 _encode_obs(o, b) | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |         elif isinstance(obs, dict) and isinstance(buffer, dict): | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |             for k in obs.keys(): | 
					
						
							|  |  |  |                 _encode_obs(obs[k], buffer[k]) | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     parent.close() | 
					
						
							|  |  |  |     env = env_fn_wrapper.data() | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         while True: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 cmd, data = p.recv() | 
					
						
							|  |  |  |             except EOFError:  # the pipe has been closed | 
					
						
							|  |  |  |                 p.close() | 
					
						
							|  |  |  |                 break | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |             if cmd == "step": | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |                 if data is None:  # reset | 
					
						
							|  |  |  |                     obs = env.reset() | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     obs, reward, done, info = env.step(data) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |                 if obs_bufs is not None: | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |                     _encode_obs(obs, obs_bufs) | 
					
						
							|  |  |  |                     obs = None | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |                 if data is None: | 
					
						
							|  |  |  |                     p.send(obs) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     p.send((obs, reward, done, info)) | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |             elif cmd == "close": | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |                 p.send(env.close()) | 
					
						
							|  |  |  |                 p.close() | 
					
						
							|  |  |  |                 break | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |             elif cmd == "render": | 
					
						
							|  |  |  |                 p.send(env.render(**data) if hasattr(env, "render") else None) | 
					
						
							|  |  |  |             elif cmd == "seed": | 
					
						
							|  |  |  |                 p.send(env.seed(data) if hasattr(env, "seed") else None) | 
					
						
							|  |  |  |             elif cmd == "getattr": | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |                 p.send(getattr(env, data) if hasattr(env, data) else None) | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |             elif cmd == "setattr": | 
					
						
							|  |  |  |                 setattr(env, data["key"], data["value"]) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 p.close() | 
					
						
							|  |  |  |                 raise NotImplementedError | 
					
						
							|  |  |  |     except KeyboardInterrupt: | 
					
						
							|  |  |  |         p.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SubprocEnvWorker(EnvWorker): | 
					
						
							|  |  |  |     """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, env_fn: Callable[[], gym.Env], share_memory: bool = False | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         self.parent_remote, self.child_remote = Pipe() | 
					
						
							|  |  |  |         self.share_memory = share_memory | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |         self.buffer: Optional[Union[dict, tuple, ShArray]] = None | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         if self.share_memory: | 
					
						
							|  |  |  |             dummy = env_fn() | 
					
						
							|  |  |  |             obs_space = dummy.observation_space | 
					
						
							|  |  |  |             dummy.close() | 
					
						
							|  |  |  |             del dummy | 
					
						
							|  |  |  |             self.buffer = _setup_buf(obs_space) | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         args = ( | 
					
						
							|  |  |  |             self.parent_remote, | 
					
						
							|  |  |  |             self.child_remote, | 
					
						
							|  |  |  |             CloudpickleWrapper(env_fn), | 
					
						
							|  |  |  |             self.buffer, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         self.process = Process(target=_worker, args=args, daemon=True) | 
					
						
							|  |  |  |         self.process.start() | 
					
						
							|  |  |  |         self.child_remote.close() | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |         self.is_reset = False | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         super().__init__(env_fn) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |     def get_env_attr(self, key: str) -> Any: | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         self.parent_remote.send(["getattr", key]) | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         return self.parent_remote.recv() | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-02 17:08:00 +01:00
										 |  |  |     def set_env_attr(self, key: str, value: Any) -> None: | 
					
						
							|  |  |  |         self.parent_remote.send(["setattr", {"key": key, "value": value}]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def _decode_obs(self) -> Union[dict, tuple, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         def decode_obs( | 
					
						
							|  |  |  |             buffer: Optional[Union[dict, tuple, ShArray]] | 
					
						
							|  |  |  |         ) -> Union[dict, tuple, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |             if isinstance(buffer, ShArray): | 
					
						
							|  |  |  |                 return buffer.get() | 
					
						
							|  |  |  |             elif isinstance(buffer, tuple): | 
					
						
							|  |  |  |                 return tuple([decode_obs(b) for b in buffer]) | 
					
						
							|  |  |  |             elif isinstance(buffer, dict): | 
					
						
							|  |  |  |                 return {k: decode_obs(v) for k, v in buffer.items()} | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return decode_obs(self.buffer) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |     def wait(  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         workers: List["SubprocEnvWorker"], | 
					
						
							|  |  |  |         wait_num: int, | 
					
						
							|  |  |  |         timeout: Optional[float] = None, | 
					
						
							|  |  |  |     ) -> List["SubprocEnvWorker"]: | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |         remain_conns = conns = [x.parent_remote for x in workers] | 
					
						
							|  |  |  |         ready_conns: List[connection.Connection] = [] | 
					
						
							|  |  |  |         remain_time, t1 = timeout, time.time() | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         while len(remain_conns) > 0 and len(ready_conns) < wait_num: | 
					
						
							|  |  |  |             if timeout: | 
					
						
							|  |  |  |                 remain_time = timeout - (time.time() - t1) | 
					
						
							|  |  |  |                 if remain_time <= 0: | 
					
						
							|  |  |  |                     break | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |             # connection.wait hangs if the list is empty | 
					
						
							| 
									
										
										
										
											2021-03-01 15:44:03 +08:00
										 |  |  |             new_ready_conns = connection.wait(remain_conns, timeout=remain_time) | 
					
						
							| 
									
										
										
										
											2020-09-13 19:31:50 +08:00
										 |  |  |             ready_conns.extend(new_ready_conns)  # type: ignore | 
					
						
							| 
									
										
										
										
											2021-03-01 15:44:03 +08:00
										 |  |  |             remain_conns = [conn for conn in remain_conns if conn not in ready_conns] | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         return [workers[conns.index(con)] for con in ready_conns] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |     def send(self, action: Optional[np.ndarray]) -> None: | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         self.parent_remote.send(["step", action]) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-08 00:40:01 +08:00
										 |  |  |     def recv( | 
					
						
							|  |  |  |         self | 
					
						
							|  |  |  |     ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: | 
					
						
							|  |  |  |         result = self.parent_remote.recv() | 
					
						
							|  |  |  |         if isinstance(result, tuple): | 
					
						
							|  |  |  |             obs, rew, done, info = result | 
					
						
							|  |  |  |             if self.share_memory: | 
					
						
							|  |  |  |                 obs = self._decode_obs() | 
					
						
							|  |  |  |             return obs, rew, done, info | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             obs = result | 
					
						
							|  |  |  |             if self.share_memory: | 
					
						
							|  |  |  |                 obs = self._decode_obs() | 
					
						
							|  |  |  |             return obs | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         super().seed(seed) | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         self.parent_remote.send(["seed", seed]) | 
					
						
							| 
									
										
										
										
											2021-03-02 12:28:28 +08:00
										 |  |  |         return self.parent_remote.recv() | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def render(self, **kwargs: Any) -> Any: | 
					
						
							|  |  |  |         self.parent_remote.send(["render", kwargs]) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         return self.parent_remote.recv() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def close_env(self) -> None: | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |             self.parent_remote.send(["close", None]) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |             # mp may be deleted so it may raise AttributeError | 
					
						
							|  |  |  |             self.parent_remote.recv() | 
					
						
							|  |  |  |             self.process.join() | 
					
						
							|  |  |  |         except (BrokenPipeError, EOFError, AttributeError): | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  |         # ensure the subproc is terminated | 
					
						
							|  |  |  |         self.process.terminate() |