Tianshou/tianshou/env/vecenv/shmemenv.py

178 lines
6.3 KiB
Python
Raw Normal View History

import gym
import ctypes
import numpy as np
from collections import OrderedDict
from multiprocessing import Pipe, Process, Array
from typing import Callable, List, Optional, Tuple, Union
from tianshou.env import BaseVectorEnv, SubprocVectorEnv
from tianshou.env.utils import CloudpickleWrapper
_NP_TO_CT = {np.bool: ctypes.c_bool,
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double}
def _shmem_worker(parent, p, env_fn_wrapper, obs_bufs):
"""Control a single environment instance using IPC and shared memory."""
def _encode_obs(obs, buffer):
if isinstance(obs, np.ndarray):
buffer.save(obs)
elif isinstance(obs, tuple):
for o, b in zip(obs, buffer):
_encode_obs(o, b)
elif isinstance(obs, dict):
for k in obs.keys():
_encode_obs(obs[k], buffer[k])
return None
parent.close()
env = env_fn_wrapper.data()
try:
while True:
cmd, data = p.recv()
if cmd == 'step':
obs, reward, done, info = env.step(data)
p.send((_encode_obs(obs, obs_bufs), reward, done, info))
elif cmd == 'reset':
p.send(_encode_obs(env.reset(), obs_bufs))
elif cmd == 'close':
p.send(env.close())
p.close()
break
elif cmd == 'render':
p.send(env.render(**data) if hasattr(env, 'render') else None)
elif cmd == 'seed':
p.send(env.seed(data) if hasattr(env, 'seed') else None)
elif cmd == 'getattr':
p.send(getattr(env, data) if hasattr(env, data) else None)
else:
p.close()
raise NotImplementedError
except KeyboardInterrupt:
p.close()
class ShArray:
"""Wrapper of multiprocessing Array"""
def __init__(self, dtype, shape):
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
def save(self, ndarray):
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
np.copyto(dst_np, ndarray)
def get(self):
return np.frombuffer(self.arr.get_obj(),
dtype=self.dtype).reshape(self.shape)
class ShmemVectorEnv(SubprocVectorEnv):
"""Optimized version of SubprocVectorEnv that uses shared variables to
communicate observations. SubprocVectorEnv has exactly the same API as
SubprocVectorEnv.
.. seealso::
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
detailed explanation.
ShmemVectorEnv Class was inspired by openai baseline's implementation.
Please refer to 'https://github.com/openai/baselines/blob/master/baselines/
common/vec_env/shmem_vec_env.py' for more info if you are interested.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
BaseVectorEnv.__init__(self, env_fns)
# Mind that SubprocVectorEnv is not initialised.
self.closed = False
dummy = env_fns[0]()
obs_space = dummy.observation_space
dummy.close()
del dummy
self.obs_bufs = [ShmemVectorEnv._setup_buf(obs_space)
for _ in range(self.env_num)]
self.parent_remote, self.child_remote = \
zip(*[Pipe() for _ in range(self.env_num)])
self.processes = [
Process(target=_shmem_worker, args=(
parent, child, CloudpickleWrapper(env_fn),
obs_buf), daemon=True)
for (parent, child, env_fn, obs_buf) in zip(
self.parent_remote, self.child_remote, env_fns, self.obs_bufs)
]
for p in self.processes:
p.start()
for c in self.child_remote:
c.close()
def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
assert len(action) == len(id)
for i, j in enumerate(id):
self.parent_remote[j].send(['step', action[i]])
result = []
for i in id:
obs, rew, done, info = self.parent_remote[i].recv()
obs = self._decode_obs(obs, i)
result.append((obs, rew, done, info))
obs, rew, done, info = map(np.stack, zip(*result))
return obs, rew, done, info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
for i in id:
self.parent_remote[i].send(['reset', None])
obs = np.stack(
[self._decode_obs(self.parent_remote[i].recv(), i) for i in id])
return obs
@staticmethod
def _setup_buf(space):
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
buffer = {k: ShmemVectorEnv._setup_buf(v)
for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
buffer = tuple([ShmemVectorEnv._setup_buf(t)
for t in space.spaces])
else:
buffer = ShArray(space.dtype, space.shape)
return buffer
def _decode_obs(self, isNone, index):
def decode_obs(buffer):
if isinstance(buffer, ShArray):
return buffer.get()
elif isinstance(buffer, tuple):
return tuple([decode_obs(b) for b in buffer])
elif isinstance(buffer, dict):
return {k: decode_obs(v) for k, v in buffer.items()}
else:
raise NotImplementedError
return decode_obs(self.obs_bufs[index])