import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any

from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper


def _worker(parent, p, env_fn_wrapper):
    """Serve commands for a single environment inside a child process.

    The function loops on ``p.recv()``, expecting ``(cmd, data)`` pairs, and
    answers every command with exactly one ``p.send(...)``:

    * ``'step'``: reply with ``env.step(data)``.
    * ``'reset'``: reply with ``env.reset()``.
    * ``'close'``: reply with ``env.close()``, close the pipe and return.
    * ``'render'``: reply with ``env.render(**data)``, or ``None`` if the
      environment has no ``render`` attribute.
    * ``'seed'``: reply with ``env.seed(data)``, or ``None`` if the
      environment has no ``seed`` attribute.
    * ``'getattr'``: reply with ``getattr(env, data)``, or ``None`` if the
      attribute does not exist.

    :param parent: the parent-side end of the pipe; it is closed right away
        because this function only talks through ``p``.
    :param p: the child-side end of the pipe used to receive commands and
        send back results.
    :param env_fn_wrapper: an object whose ``data`` attribute is a callable
        that builds the environment.

    :raises NotImplementedError: if an unknown command is received (the pipe
        is closed first).
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                # A missing attribute is reported as None, not as an error.
                p.send(getattr(env, data, None))
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        # Leave quietly when the process is interrupted.
        p.close()


class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    Each environment factory is run in its own daemon process executing
    :func:`_worker`; this object talks to every worker through one pipe.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more
        detailed explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Create one pipe and one worker process per entry of ``env_fns``.

        :param env_fns: callables that each build one environment; they are
            wrapped in ``CloudpickleWrapper`` before being handed to the
            worker process.
        """
        super().__init__(env_fns)
        self.closed = False
        # NOTE(review): self.env_num is presumably set by
        # BaseVectorEnv.__init__ -- it is not assigned in this class.
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=_worker, args=(
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        for p in self.processes:
            p.start()
        # The child ends are only used inside the workers; drop this
        # process's handles to them.
        for c in self.child_remote:
            c.close()

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every worker's environment.

        Only invoked when normal attribute lookup on this object fails.

        :return: a list with one entry per environment (``None`` for
            environments lacking the attribute).
        :raises AttributeError: if the worker pipes have not been created on
            this instance.
        """
        # Reading ``self.parent_remote`` while it is missing would re-enter
        # this method and recurse until RecursionError, so look it up in the
        # instance dict directly and fail with a regular AttributeError.
        remotes = self.__dict__.get('parent_remote')
        if remotes is None:
            raise AttributeError(
                '{!r} object has no attribute {!r}'.format(
                    type(self).__name__, key))
        for p in remotes:
            p.send(['getattr', key])
        return [p.recv() for p in remotes]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments with the given actions.

        :param action: one action per selected environment, in the same
            order as ``id``.
        :param id: index or indices of the environments to step; ``None``
            selects all of them.

        :return: ``(obs, rew, done, info)`` where each element is the
            ``np.stack`` of the corresponding items returned by the workers.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # Send every command before reading any reply so the workers run
        # concurrently.
        for act, env_id in zip(action, id):
            self.parent_remote[env_id].send(['step', act])
        result = [self.parent_remote[env_id].recv() for env_id in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments.

        :param id: index or indices of the environments to reset; ``None``
            selects all of them.

        :return: the ``np.stack`` of the observations returned by the
            selected workers.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for env_id in id:
            self.parent_remote[env_id].send(['reset', None])
        obs = np.stack([self.parent_remote[env_id].recv() for env_id in id])
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment.

        :param seed: a scalar (environment ``i`` receives ``seed + i``),
            ``None`` (every environment receives ``None``), or a sequence
            holding one seed per environment.

        :return: the list of values returned by each worker's ``seed`` call.
        :raises ValueError: if a sequence with fewer than ``env_num`` seeds
            is given.
        """
        if np.isscalar(seed):
            seed = [seed + i for i in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        else:
            seed = list(seed)
            # A reply is awaited from every worker below, so a short list
            # would block forever on the workers that got no command.
            if len(seed) < self.env_num:
                raise ValueError(
                    'expected at least {} seeds, got {}'.format(
                        self.env_num, len(seed)))
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Render every environment, forwarding ``kwargs`` to each of them.

        :return: the list of values returned by each worker.
        """
        for p in self.parent_remote:
            p.send(['render', kwargs])
        return [p.recv() for p in self.parent_remote]

    def close(self) -> List[Any]:
        """Close every environment and wait for the worker processes.

        Workers whose pipe is already broken are skipped instead of aborting
        the shutdown, so the remaining workers are still closed and joined.

        :return: the list of values returned by each worker's ``close``
            (``None`` for a worker that could not be reached), or an empty
            list if this object was already closed.
        """
        if self.closed:
            return []
        # Ask all workers to stop before waiting on any of them, so that
        # they shut down concurrently.
        reachable = []
        for p in self.parent_remote:
            try:
                p.send(['close', None])
            except OSError:  # includes BrokenPipeError
                reachable.append(False)
            else:
                reachable.append(True)
        result = []
        for p, ok in zip(self.parent_remote, reachable):
            reply = None
            if ok:
                try:
                    reply = p.recv()
                except (EOFError, OSError):
                    reply = None
            result.append(reply)
        self.closed = True
        # Release this process's pipe handles.
        for p in self.parent_remote:
            p.close()
        for proc in self.processes:
            proc.join()
        return result