116 lines
3.8 KiB
Python
Raw Normal View History

import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any
from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper
def worker(parent, p, env_fn_wrapper):
    """Serve commands for a single environment inside a subprocess.

    The loop receives ``[cmd, data]`` messages on ``p`` and sends back exactly
    one reply per command until a ``'close'`` command arrives, the other end
    of the pipe goes away, or the process is interrupted.

    :param parent: the parent-side connection of the pipe; this process does
        not use it, so it is closed immediately.
    :param p: the child-side connection used to receive commands and send
        replies. It is always closed when this function exits.
    :param env_fn_wrapper: an object whose ``data`` attribute is a
        zero-argument callable that builds the environment.
    :raises NotImplementedError: if an unknown command is received.
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            try:
                cmd, data = p.recv()
            except EOFError:
                # The parent end was closed without a 'close' command
                # (e.g. the parent died); there is nobody left to serve.
                break
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                p.send(env.close())
                break
            elif cmd == 'render':
                # optional methods: reply None when the env does not have them
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                p.send(getattr(env, data, None))
            else:
                raise NotImplementedError(
                    'Unknown command: {!r}'.format(cmd))
    except KeyboardInterrupt:
        # Ctrl-C is delivered to every process; exit quietly and let the
        # parent handle the interrupt.
        pass
    finally:
        p.close()
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    Each environment is built and run in its own daemon
    :class:`~multiprocessing.Process` executing :func:`worker`; this object
    talks to each one through a :func:`~multiprocessing.Pipe`.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.closed = False
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=worker, args=(
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        for p in self.processes:
            p.start()
        # The children now hold the child ends; drop the parent's copies so
        # that a worker closing its end is seen as EOF here.
        for c in self.child_remote:
            c.close()

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every sub-environment.

        :return: a list with one entry per environment; an environment that
            does not have the attribute contributes ``None``.
        :raises AttributeError: for dunder names, and for any lookup made
            before the pipes exist.
        """
        # Without this guard, a lookup on an instance lacking parent_remote
        # (mid-construction, or created via __new__ by copy/pickle) would
        # recurse into __getattr__('parent_remote') forever, and protocol
        # probes such as __getstate__ would be answered by the workers.
        if (key.startswith('__') and key.endswith('__')) \
                or 'parent_remote' not in self.__dict__:
            raise AttributeError(key)
        for p in self.parent_remote:
            p.send(['getattr', key])
        return [p.recv() for p in self.parent_remote]

    def _normalize_id(self, id):
        """Expand ``None`` to all env indices and wrap a scalar in a list."""
        if id is None:
            return range(self.env_num)
        if np.isscalar(id):
            return [id]
        return id

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments.

        :param action: one action per selected environment, in the same order
            as ``id``.
        :param id: an env index, a list of indices, or ``None`` for all.
        :return: ``(obs, rew, done, info)``, each stacked with ``np.stack``
            along a new first axis in ``id`` order.
        """
        id = self._normalize_id(id)
        assert len(action) == len(id)
        # send every command first so the workers step in parallel
        for i, j in enumerate(id):
            self.parent_remote[j].send(['step', action[i]])
        result = [self.parent_remote[i].recv() for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments (all when ``id`` is ``None``).

        :return: the initial observations stacked with ``np.stack`` in ``id``
            order.
        """
        id = self._normalize_id(id)
        for i in id:
            self.parent_remote[i].send(['reset', None])
        obs = np.stack([self.parent_remote[i].recv() for i in id])
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment.

        A scalar ``seed`` gives environment ``i`` the seed ``seed + i``;
        ``None`` passes ``None`` to every environment; otherwise ``seed`` is
        used as a per-environment sequence.

        :return: the list of per-environment replies.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Render every environment, forwarding ``kwargs`` to each one.

        :return: the list of per-environment replies.
        """
        for p in self.parent_remote:
            p.send(['render', kwargs])
        return [p.recv() for p in self.parent_remote]

    def close(self) -> List[Any]:
        """Close every environment and join the worker processes.

        Calling this more than once is a no-op that returns ``[]``.

        :return: the list of per-environment replies to ``close``.
        """
        if self.closed:
            return []
        for p in self.parent_remote:
            p.send(['close', None])
        result = [p.recv() for p in self.parent_remote]
        self.closed = True
        for p in self.processes:
            p.join()
        # release the parent-side pipe handles instead of leaking them
        for p in self.parent_remote:
            p.close()
        return result