ChenDRAG f2bcc55a25
ShmemVectorEnv Implementation (#174)
* add shmem vecenv, some add&fix in test_env

* generalize test_env IO

* pep8 fix

* comment update

* style change

* pep8 fix

* style fix

* minor fix

* fix a bug

* test fix

* change env

* testenv bug fix& shmem support recurse dict

* bugfix

* pep8 fix

* _NP_TO_CT enhance

* doc update

* docstring update

* pep8 fix

* style change

* style fix

* remove assert

* minor

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-04 13:39:05 +08:00

116 lines
3.8 KiB
Python

import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any
from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper
def _worker(parent, p, env_fn_wrapper):
    """Serve a single environment from inside a worker process.

    The function builds the environment by calling ``env_fn_wrapper.data()``
    and then answers ``(cmd, data)`` messages received on ``p`` until it is
    told to close, the other end of the pipe goes away, or the process is
    interrupted.

    Supported commands (one reply is sent on ``p`` for each of them):

    * ``'step'``: reply with ``env.step(data)``.
    * ``'reset'``: reply with ``env.reset()``.
    * ``'close'``: reply with ``env.close()``, close ``p`` and return.
    * ``'render'``: reply with ``env.render(**data)``, or ``None`` if the
      environment has no ``render`` attribute.
    * ``'seed'``: reply with ``env.seed(data)``, or ``None`` if the
      environment has no ``seed`` attribute.
    * ``'getattr'``: reply with the attribute named ``data``, or ``None`` if
      the environment does not have it.

    :param parent: the other end of the pipe; it is not used here and is
        closed immediately.
    :param p: the connection this worker receives commands on and sends
        replies to.
    :param env_fn_wrapper: an object whose ``data`` attribute is a callable
        returning the environment.

    :raises NotImplementedError: if an unknown command is received (``p`` is
        closed first).
    """
    # This process only talks through ``p``; drop the handle to the other end.
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            try:
                cmd, data = p.recv()
            except EOFError:
                # Nothing left to receive and the other end was closed:
                # shut down quietly instead of dying with a traceback.
                p.close()
                break
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                # Single lookup with a default instead of hasattr + getattr.
                p.send(getattr(env, data, None))
            else:
                p.close()
                raise NotImplementedError(
                    'unknown command: {!r}'.format(cmd))
    except KeyboardInterrupt:
        p.close()
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    Every callable in ``env_fns`` is handed to its own daemon
    :class:`multiprocessing.Process`, and the parent process talks to each
    worker through a dedicated :func:`multiprocessing.Pipe`.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.closed = False
        # One pipe per environment: ``parent_remote[i]`` stays in this
        # process, ``child_remote[i]`` is handed to worker ``i``.
        # NOTE(review): ``self.env_num`` is not set in this class; presumably
        # the base-class constructor provides it.
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=_worker, args=(
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        for p in self.processes:
            p.start()
        # The workers hold their own handles now; the parent's copies of the
        # child ends are no longer needed.
        for c in self.child_remote:
            c.close()

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every worker environment.

        Python only calls this hook when normal attribute lookup fails.

        :return: a list with one reply per worker.

        :raises AttributeError: if the pipes have not been created yet.
        """
        # ``self.parent_remote`` would re-enter ``__getattr__`` when the
        # attribute is missing (e.g. ``__init__`` failed early), recursing
        # until RecursionError. Look in the instance dict directly instead.
        try:
            remotes = self.__dict__['parent_remote']
        except KeyError:
            raise AttributeError(key) from None
        for p in remotes:
            p.send(['getattr', key])
        return [p.recv() for p in remotes]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Send one action to each selected environment and gather replies.

        :param action: one action per selected environment, in the same order
            as ``id``.
        :param id: index or list of indices of the environments to step;
            ``None`` selects all of them.

        :return: ``(obs, rew, done, info)``, each built with ``np.stack`` over
            the corresponding element of every worker's reply.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # Dispatch everything first so the workers run concurrently, then
        # collect the replies in the same order.
        for i, j in enumerate(id):
            self.parent_remote[j].send(['step', action[i]])
        result = [self.parent_remote[i].recv() for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments and return their stacked replies.

        :param id: index or list of indices of the environments to reset;
            ``None`` selects all of them.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for i in id:
            self.parent_remote[i].send(['reset', None])
        obs = np.stack([self.parent_remote[i].recv() for i in id])
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment.

        :param seed: a scalar ``s`` seeds environment ``i`` with ``s + i``;
            ``None`` is passed through to every environment; otherwise a
            sequence providing one seed per environment.

        :return: a list with one reply per worker.

        :raises ValueError: if a sequence with fewer entries than
            environments is given.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        else:
            seed = list(seed)
            if len(seed) < self.env_num:
                # Otherwise only some workers would receive a command while
                # a reply is awaited from all of them, blocking forever.
                raise ValueError(
                    'expected at least {} seeds, got {}'.format(
                        self.env_num, len(seed)))
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Ask every worker to render, forwarding ``kwargs``.

        :return: a list with one reply per worker.
        """
        for p in self.parent_remote:
            p.send(['render', kwargs])
        return [p.recv() for p in self.parent_remote]

    def close(self) -> List[Any]:
        """Close all environments and wait for the worker processes to end.

        :return: a list with one reply per worker, or an empty list if the
            environments were already closed.
        """
        if self.closed:
            return []
        for p in self.parent_remote:
            p.send(['close', None])
        result = [p.recv() for p in self.parent_remote]
        self.closed = True
        for p in self.processes:
            p.join()
        # The workers are gone; release the parent ends of the pipes so their
        # handles are not leaked.
        for p in self.parent_remote:
            p.close()
        return result