Tianshou/tianshou/env/vecenv/shmemenv.py

import gym
import ctypes
import numpy as np
from collections import OrderedDict
from multiprocessing import Pipe, Process, Array
from typing import Callable, List, Optional, Tuple, Union

from tianshou.env import BaseVectorEnv, SubprocVectorEnv
from tianshou.env.utils import CloudpickleWrapper

# Mapping from numpy dtypes to ctypes types, used to allocate the shared
# multiprocessing.Array backing each observation buffer.
_NP_TO_CT = {np.bool: ctypes.c_bool,
             np.bool_: ctypes.c_bool,
             np.uint8: ctypes.c_uint8,
             np.uint16: ctypes.c_uint16,
             np.uint32: ctypes.c_uint32,
             np.uint64: ctypes.c_uint64,
             np.int8: ctypes.c_int8,
             np.int16: ctypes.c_int16,
             np.int32: ctypes.c_int32,
             np.int64: ctypes.c_int64,
             np.float32: ctypes.c_float,
             np.float64: ctypes.c_double}


def _shmem_worker(parent, p, env_fn_wrapper, obs_bufs):
    """Control a single environment instance using IPC and shared memory."""
    def _encode_obs(obs, buffer):
        # Recursively copy the observation into the shared-memory buffer(s);
        # only a placeholder (None) is sent back through the pipe.
        if isinstance(obs, np.ndarray):
            buffer.save(obs)
        elif isinstance(obs, tuple):
            for o, b in zip(obs, buffer):
                _encode_obs(o, b)
        elif isinstance(obs, dict):
            for k in obs.keys():
                _encode_obs(obs[k], buffer[k])
        return None
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                obs, reward, done, info = env.step(data)
                p.send((_encode_obs(obs, obs_bufs), reward, done, info))
            elif cmd == 'reset':
                p.send(_encode_obs(env.reset(), obs_bufs))
            elif cmd == 'close':
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                p.send(getattr(env, data) if hasattr(env, data) else None)
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()


class ShArray:
    """Wrapper of multiprocessing Array."""

    def __init__(self, dtype, shape):
        self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
        self.dtype = dtype
        self.shape = shape

    def save(self, ndarray):
        assert isinstance(ndarray, np.ndarray)
        dst = self.arr.get_obj()
        dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
        np.copyto(dst_np, ndarray)

    def get(self):
        return np.frombuffer(self.arr.get_obj(),
                             dtype=self.dtype).reshape(self.shape)
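

# A minimal usage sketch of ShArray (illustrative only, shown as comments so
# nothing is allocated at import time): it wraps a shared
# ``multiprocessing.Array`` sized from the dtype and shape, and round-trips a
# numpy array through ``save``/``get``. The dtype and shape below are
# arbitrary example values, not anything mandated by this module.
#
#     buf = ShArray(np.dtype(np.float32), (4, 84, 84))
#     frame = np.zeros((4, 84, 84), dtype=np.float32)
#     buf.save(frame)                       # copy into shared memory
#     assert (buf.get() == frame).all()     # read the data back out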


class ShmemVectorEnv(SubprocVectorEnv):
    """Optimized version of SubprocVectorEnv that uses shared variables to
    communicate observations. ShmemVectorEnv has exactly the same API as
    SubprocVectorEnv.

    .. seealso::

        Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
        detailed explanation.

    The ShmemVectorEnv class was inspired by OpenAI Baselines' implementation.
    Please refer to 'https://github.com/openai/baselines/blob/master/baselines/
    common/vec_env/shmem_vec_env.py' for more info if you are interested.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        BaseVectorEnv.__init__(self, env_fns)
        # Mind that SubprocVectorEnv is not initialised.
        self.closed = False
        # A throwaway env is created only to read the observation space,
        # which determines the layout of the shared-memory buffers.
        dummy = env_fns[0]()
        obs_space = dummy.observation_space
        dummy.close()
        del dummy
        self.obs_bufs = [ShmemVectorEnv._setup_buf(obs_space)
                         for _ in range(self.env_num)]
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=_shmem_worker, args=(
                parent, child, CloudpickleWrapper(env_fn),
                obs_buf), daemon=True)
            for (parent, child, env_fn, obs_buf) in zip(
                self.parent_remote, self.child_remote, env_fns, self.obs_bufs)
        ]
        for p in self.processes:
            p.start()
        for c in self.child_remote:
            c.close()

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        for i, j in enumerate(id):
            self.parent_remote[j].send(['step', action[i]])
        result = []
        for i in id:
            # The worker sends None in place of the observation; the actual
            # data is read back from the shared-memory buffer.
            obs, rew, done, info = self.parent_remote[i].recv()
            obs = self._decode_obs(obs, i)
            result.append((obs, rew, done, info))
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for i in id:
            self.parent_remote[i].send(['reset', None])
        obs = np.stack(
            [self._decode_obs(self.parent_remote[i].recv(), i) for i in id])
        return obs

    @staticmethod
    def _setup_buf(space):
        if isinstance(space, gym.spaces.Dict):
            assert isinstance(space.spaces, OrderedDict)
            buffer = {k: ShmemVectorEnv._setup_buf(v)
                      for k, v in space.spaces.items()}
        elif isinstance(space, gym.spaces.Tuple):
            assert isinstance(space.spaces, tuple)
            buffer = tuple([ShmemVectorEnv._setup_buf(t)
                            for t in space.spaces])
        else:
            buffer = ShArray(space.dtype, space.shape)
        return buffer

    def _decode_obs(self, isNone, index):
        # ``isNone`` is the placeholder received over the pipe and is ignored;
        # the observation is reconstructed from the shared-memory buffers.
        def decode_obs(buffer):
            if isinstance(buffer, ShArray):
                return buffer.get()
            elif isinstance(buffer, tuple):
                return tuple([decode_obs(b) for b in buffer])
            elif isinstance(buffer, dict):
                return {k: decode_obs(v) for k, v in buffer.items()}
            else:
                raise NotImplementedError
        return decode_obs(self.obs_bufs[index])
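

# A minimal usage sketch, not part of the library API: it assumes gym with
# 'CartPole-v0' registered is installed and only runs when this module is
# executed as a script. It exercises the public interface inherited from
# SubprocVectorEnv (reset, step, close); observations travel via the shared
# buffers rather than the pipes.
if __name__ == '__main__':
    env_num = 4
    venv = ShmemVectorEnv([lambda: gym.make('CartPole-v0')
                           for _ in range(env_num)])
    obs = venv.reset()                   # stacked obs, (env_num, 4) for CartPole
    act = np.zeros(env_num, dtype=np.int64)  # action 0 is valid for CartPole
    obs, rew, done, info = venv.step(act)
    venv.close()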