import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any

from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper


def worker(parent, p, env_fn_wrapper):
    """Serve environment commands received over a pipe.

    This is the target function of each subprocess started by
    :class:`SubprocVectorEnv`. It builds one environment and then answers
    ``(cmd, data)`` messages until it is told to close.

    :param parent: the parent-side end of the pipe; it is closed right away
        because this process only talks through ``p``.
    :param p: the child-side end of the pipe used to receive commands and
        send back results.
    :param env_fn_wrapper: an object whose ``data`` attribute is called with
        no arguments to create the environment.

    Supported commands: ``'step'``, ``'reset'``, ``'close'``, ``'render'``,
    ``'seed'`` and ``'getattr'``. ``'render'`` and ``'seed'`` reply ``None``
    when the environment has no such method; ``'getattr'`` replies ``None``
    when the attribute is missing.

    :raises NotImplementedError: if an unknown command is received (the pipe
        is closed first).
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                # Reply with the result of env.close() before shutting down,
                # so the parent's pending recv() is satisfied.
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                # Single lookup with a default instead of hasattr + getattr,
                # which would evaluate the attribute twice.
                p.send(getattr(env, data, None))
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()


class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Start one daemon subprocess per environment factory.

        :param env_fns: callables that each create one environment; every
            callable is wrapped in ``CloudpickleWrapper`` and handed to its
            own :func:`worker` process.
        """
        super().__init__(env_fns)
        self.closed = False
        # One pipe per environment: parent_remote[i] is this process's end,
        # child_remote[i] is the end used by the i-th worker.
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=worker, args=(
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        for p in self.processes:
            p.start()
        # The child ends are only used inside the workers; close this
        # process's copies.
        for c in self.child_remote:
            c.close()

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every worker's environment.

        Python calls this only when normal attribute lookup fails. The
        request is forwarded to all workers and the replies are returned as
        a list (a worker replies ``None`` when its environment lacks the
        attribute).

        :raises AttributeError: if the pipes have not been created on this
            instance, so the request cannot be forwarded.
        """
        # Read from the instance dict directly: going through
        # ``self.parent_remote`` while it is missing would re-enter
        # __getattr__ and recurse until RecursionError.
        remotes = self.__dict__.get('parent_remote')
        if remotes is None:
            raise AttributeError(key)
        for p in remotes:
            p.send(['getattr', key])
        return [p.recv() for p in remotes]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments with the given actions.

        :param action: one action per selected environment, in the same
            order as ``id``.
        :param id: index or indices of the environments to step; ``None``
            selects all of them.
        :return: ``(obs, rew, done, info)``, each built by stacking the
            corresponding element of every worker's reply with ``np.stack``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # Send every request before reading any reply so the workers run
        # their steps concurrently.
        for i, j in enumerate(id):
            self.parent_remote[j].send(['step', action[i]])
        result = [self.parent_remote[i].recv() for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments and return their observations.

        :param id: index or indices of the environments to reset; ``None``
            selects all of them.
        :return: the workers' replies to ``'reset'``, stacked with
            ``np.stack`` in the order given by ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for i in id:
            self.parent_remote[i].send(['reset', None])
        obs = np.stack([self.parent_remote[i].recv() for i in id])
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment and return the workers' replies.

        :param seed: a scalar is expanded to ``seed, seed + 1, ...`` (one
            value per environment); ``None`` is sent to every environment
            unchanged; any other value is paired element-wise with the
            environments.
        :return: one reply per worker (``None`` from a worker whose
            environment has no ``seed`` method).
        """
        if np.isscalar(seed):
            seed = [seed + i for i in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Render every environment, forwarding ``kwargs`` to ``env.render``.

        :return: one reply per worker (``None`` from a worker whose
            environment has no ``render`` method).
        """
        for p in self.parent_remote:
            p.send(['render', kwargs])
        return [p.recv() for p in self.parent_remote]

    def close(self) -> List[Any]:
        """Close every environment and wait for the subprocesses to exit.

        :return: the result of ``env.close()`` from each worker, or an empty
            list if this vector env has already been closed.
        """
        if self.closed:
            return []
        for p in self.parent_remote:
            p.send(['close', None])
        result = [p.recv() for p in self.parent_remote]
        self.closed = True
        for p in self.processes:
            p.join()
        return result