ShmemVectorEnv Implementation (#174)

* add shmem vecenv, some add&fix in test_env

* generalize test_env IO

* pep8 fix

* comment update

* style change

* pep8 fix

* style fix

* minor fix

* fix a bug

* test fix

* change env

* testenv bug fix& shmem support recurse dict

* bugfix

* pep8 fix

* _NP_TO_CT enhance

* doc update

* docstring update

* pep8 fix

* style change

* style fix

* remove assert

* minor

Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
ChenDRAG 2020-08-04 13:39:05 +08:00 committed by GitHub
parent 996e2f7c9b
commit f2bcc55a25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 284 additions and 49 deletions

View File

@ -31,7 +31,7 @@ See :ref:`customized_trainer`.
Parallel Sampling
-----------------
Use :class:`~tianshou.env.VectorEnv` or :class:`~tianshou.env.SubprocVectorEnv`.
Use :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv` or :class:`~tianshou.env.ShmemVectorEnv`.
::
env_fns = [

View File

@ -30,7 +30,7 @@ It is available if you want the original ``gym.Env``:
train_envs = gym.make('CartPole-v0')
test_envs = gym.make('CartPole-v0')
Tianshou supports parallel sampling for all algorithms. It provides three types of vectorized environment wrapper: :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows:
Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows:
::
train_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])

View File

@ -2,35 +2,55 @@ import gym
import time
import random
import numpy as np
from gym.spaces import Discrete, MultiDiscrete, Box
from gym.spaces import Discrete, MultiDiscrete, Box, Dict, Tuple
class MyTestEnv(gym.Env):
"""This is a "going right" task. The task is to go right ``size`` steps.
"""
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
multidiscrete_action=False, random_sleep=False):
def __init__(self, size, sleep=0, dict_state=False, recurse_state=False,
ma_rew=0, multidiscrete_action=False, random_sleep=False):
assert not (
dict_state and recurse_state), \
"dict_state and recurse_state cannot both be true"
self.size = size
self.sleep = sleep
self.random_sleep = random_sleep
self.dict_state = dict_state
self.recurse_state = recurse_state
self.ma_rew = ma_rew
self._md_action = multidiscrete_action
self.observation_space = Box(shape=(1, ), low=0, high=size - 1)
if dict_state:
self.observation_space = Dict(
{"index": Box(shape=(1, ), low=0, high=size - 1),
"rand": Box(shape=(1,), low=0, high=1, dtype=np.float64)})
elif recurse_state:
self.observation_space = Dict(
{"index": Box(shape=(1, ), low=0, high=size - 1),
"dict": Dict({
"tuple": Tuple((Discrete(2), Box(shape=(2,),
low=0, high=1, dtype=np.float64))),
"rand": Box(shape=(1, 2), low=0, high=1,
dtype=np.float64)})
})
else:
self.observation_space = Box(shape=(1, ), low=0, high=size - 1)
if multidiscrete_action:
self.action_space = MultiDiscrete([2, 2])
else:
self.action_space = Discrete(2)
self.reset()
self.done = False
self.index = 0
self.seed()
def seed(self, seed=0):
np.random.seed(seed)
self.rng = np.random.RandomState(seed)
def reset(self, state=0):
self.done = False
self.index = state
return self._get_dict_state()
return self._get_state()
def _get_reward(self):
"""Generate a non-scalar reward if ma_rew is True."""
@ -39,10 +59,18 @@ class MyTestEnv(gym.Env):
return [x] * self.ma_rew
return x
def _get_dict_state(self):
"""Generate a dict_state if dict_state is True."""
return {'index': self.index, 'rand': np.random.rand()} \
if self.dict_state else self.index
def _get_state(self):
"""Generate state(observation) of MyTestEnv"""
if self.dict_state:
return {'index': np.array([self.index], dtype=np.float32),
'rand': self.rng.rand(1)}
elif self.recurse_state:
return {'index': np.array([self.index], dtype=np.float32),
'dict': {"tuple": (np.array([1],
dtype=np.int64), self.rng.rand(2)),
"rand": self.rng.rand(1, 2)}}
else:
return np.array([self.index], dtype=np.float32)
def step(self, action):
if self._md_action:
@ -55,13 +83,13 @@ class MyTestEnv(gym.Env):
time.sleep(sleep_time)
if self.index == self.size:
self.done = True
return self._get_dict_state(), self._get_reward(), self.done, {}
return self._get_state(), self._get_reward(), self.done, {}
if action == 0:
self.index = max(self.index - 1, 0)
return self._get_dict_state(), self._get_reward(), self.done, \
return self._get_state(), self._get_reward(), self.done, \
{'key': 1, 'env': self} if self.dict_state else {}
elif action == 1:
self.index += 1
self.done = self.index == self.size
return self._get_dict_state(), self._get_reward(), \
return self._get_state(), self._get_reward(), \
self.done, {'key': 1, 'env': self}

View File

@ -67,10 +67,10 @@ def test_stack(size=5, bufsize=9, stack_num=4):
if done:
obs = env.reset(1)
indice = np.arange(len(buf))
assert np.allclose(buf.get(indice, 'obs'), np.array([
[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
assert np.allclose(buf.get(indice, 'obs'), np.expand_dims(
[[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]], axis=-1))
print(buf)
_, indice = buf2.sample(0)
assert indice == [2]

View File

@ -72,34 +72,40 @@ def test_collector():
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
logger.preprocess_fn)
c0.collect(n_step=3)
assert np.allclose(c0.buffer.obs[:4], [0, 1, 0, 1])
assert np.allclose(c0.buffer[:4].obs_next, [1, 2, 1, 2])
assert np.allclose(c0.buffer.obs[:4], np.expand_dims(
[0, 1, 0, 1], axis=-1))
assert np.allclose(c0.buffer[:4].obs_next, np.expand_dims(
[1, 2, 1, 2], axis=-1))
c0.collect(n_episode=3)
assert np.allclose(c0.buffer.obs[:10], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
assert np.allclose(c0.buffer[:10].obs_next, [1, 2, 1, 2, 1, 2, 1, 2, 1, 2])
assert np.allclose(c0.buffer.obs[:10], np.expand_dims(
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], axis=-1))
assert np.allclose(c0.buffer[:10].obs_next, np.expand_dims(
[1, 2, 1, 2, 1, 2, 1, 2, 1, 2], axis=-1))
c0.collect(n_step=3, random=True)
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
logger.preprocess_fn)
c1.collect(n_step=6)
assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
assert np.allclose(c1.buffer[:11].obs_next,
[1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
assert np.allclose(c1.buffer.obs[:11], np.expand_dims(
[0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3], axis=-1))
assert np.allclose(c1.buffer[:11].obs_next, np.expand_dims([
1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4], axis=-1))
c1.collect(n_episode=2)
assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
assert np.allclose(c1.buffer.obs[11:21], np.expand_dims(
[0, 1, 2, 3, 4, 0, 1, 0, 1, 2], axis=-1))
assert np.allclose(c1.buffer[11:21].obs_next,
[1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
np.expand_dims([1, 2, 3, 4, 5, 1, 2, 1, 2, 3], axis=-1))
c1.collect(n_episode=3, random=True)
c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
logger.preprocess_fn)
c2.collect(n_episode=[1, 2, 2, 2])
assert np.allclose(c2.buffer.obs_next[:26], [
assert np.allclose(c2.buffer.obs_next[:26], np.expand_dims([
1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1))
c2.reset_env()
c2.collect(n_episode=[2, 2, 2, 2])
assert np.allclose(c2.buffer.obs_next[26:54], [
assert np.allclose(c2.buffer.obs_next[26:54], np.expand_dims([
1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1))
c2.collect(n_episode=[1, 1, 1, 1], random=True)
@ -145,6 +151,8 @@ def test_collector_with_async():
assert j - i == env_lens[env_id[i]]
obs_ground_truth += list(range(j - i))
i = j
obs_ground_truth = np.expand_dims(
np.array(obs_ground_truth), axis=-1)
assert np.allclose(obs, obs_ground_truth)
@ -169,10 +177,10 @@ def test_collector_with_dict_state():
batch = c1.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, [
assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([
0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.], axis=-1))
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
Logger.single_preprocess_fn)
c2.collect(n_episode=[0, 0, 0, 10])
@ -204,10 +212,10 @@ def test_collector_with_ma():
batch = c1.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
obs = [
obs = np.array(np.expand_dims([
0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.], axis=-1))
assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs)
rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,

View File

@ -3,7 +3,7 @@ import numpy as np
from gym.spaces.discrete import Discrete
from tianshou.data import Batch
from tianshou.env import VectorEnv, SubprocVectorEnv, \
RayVectorEnv, AsyncVectorEnv
RayVectorEnv, AsyncVectorEnv, ShmemVectorEnv
if __name__ == '__main__':
from env import MyTestEnv
@ -11,6 +11,24 @@ else: # pytest
from test.base.env import MyTestEnv
def recurse_comp(a, b):
try:
if isinstance(a, np.ndarray):
if a.dtype == np.object:
return np.array(
[recurse_comp(m, n) for m, n in zip(a, b)]).all()
else:
return np.allclose(a, b)
elif isinstance(a, (list, tuple)):
return np.array(
[recurse_comp(m, n) for m, n in zip(a, b)]).all()
elif isinstance(a, dict):
return np.array(
[recurse_comp(a[k], b[k]) for k in a.keys()]).all()
except(Exception):
return False
def test_async_env(num=8, sleep=0.1):
# simplify the test case, just keep stepping
size = 10000
@ -56,17 +74,18 @@ def test_async_env(num=8, sleep=0.1):
def test_vecenv(size=10, num=8, sleep=0.001):
verbose = __name__ == '__main__'
env_fns = [
lambda i=i: MyTestEnv(size=i, sleep=sleep)
lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
for i in range(size, size + num)
]
venv = [
VectorEnv(env_fns),
SubprocVectorEnv(env_fns),
ShmemVectorEnv(env_fns),
]
if verbose:
venv.append(RayVectorEnv(env_fns))
for v in venv:
v.seed()
v.seed(0)
action_list = [1] * 5 + [0] * 10 + [1] * 20
if not verbose:
o = [v.reset() for v in venv]
@ -77,11 +96,13 @@ def test_vecenv(size=10, num=8, sleep=0.001):
if sum(C):
A = v.reset(np.where(C)[0])
o.append([A, B, C, D])
for i in zip(*o):
for j in range(1, len(i) - 1):
assert (i[0] == i[j]).all()
for index, infos in enumerate(zip(*o)):
if index == 3: # do not check info here
continue
for info in infos:
assert recurse_comp(infos[0], info)
else:
t = [0, 0, 0]
t = [0] * len(venv)
for i, e in enumerate(venv):
t[i] = time.time()
e.reset()
@ -90,9 +111,8 @@ def test_vecenv(size=10, num=8, sleep=0.001):
if sum(done) > 0:
e.reset(np.where(done)[0])
t[i] = time.time() - t[i]
print(f'VectorEnv: {t[0]:.6f}s')
print(f'SubprocVectorEnv: {t[1]:.6f}s')
print(f'RayVectorEnv: {t[2]:.6f}s')
for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s')
for v in venv:
assert v.size == list(range(size, size + num))
assert v.env_num == num

View File

@ -3,6 +3,7 @@ from tianshou.env.vecenv.dummy import VectorEnv
from tianshou.env.vecenv.subproc import SubprocVectorEnv
from tianshou.env.vecenv.asyncenv import AsyncVectorEnv
from tianshou.env.vecenv.rayenv import RayVectorEnv
from tianshou.env.vecenv.shmemenv import ShmemVectorEnv
from tianshou.env.maenv import MultiAgentEnv
__all__ = [
@ -11,5 +12,6 @@ __all__ = [
'SubprocVectorEnv',
'AsyncVectorEnv',
'RayVectorEnv',
'ShmemVectorEnv',
'MultiAgentEnv',
]

177
tianshou/env/vecenv/shmemenv.py vendored Normal file
View File

@ -0,0 +1,177 @@
import gym
import ctypes
import numpy as np
from collections import OrderedDict
from multiprocessing import Pipe, Process, Array
from typing import Callable, List, Optional, Tuple, Union
from tianshou.env import BaseVectorEnv, SubprocVectorEnv
from tianshou.env.utils import CloudpickleWrapper
_NP_TO_CT = {np.bool: ctypes.c_bool,
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double}
def _shmem_worker(parent, p, env_fn_wrapper, obs_bufs):
"""Control a single environment instance using IPC and shared memory."""
def _encode_obs(obs, buffer):
if isinstance(obs, np.ndarray):
buffer.save(obs)
elif isinstance(obs, tuple):
for o, b in zip(obs, buffer):
_encode_obs(o, b)
elif isinstance(obs, dict):
for k in obs.keys():
_encode_obs(obs[k], buffer[k])
return None
parent.close()
env = env_fn_wrapper.data()
try:
while True:
cmd, data = p.recv()
if cmd == 'step':
obs, reward, done, info = env.step(data)
p.send((_encode_obs(obs, obs_bufs), reward, done, info))
elif cmd == 'reset':
p.send(_encode_obs(env.reset(), obs_bufs))
elif cmd == 'close':
p.send(env.close())
p.close()
break
elif cmd == 'render':
p.send(env.render(**data) if hasattr(env, 'render') else None)
elif cmd == 'seed':
p.send(env.seed(data) if hasattr(env, 'seed') else None)
elif cmd == 'getattr':
p.send(getattr(env, data) if hasattr(env, data) else None)
else:
p.close()
raise NotImplementedError
except KeyboardInterrupt:
p.close()
class ShArray:
"""Wrapper of multiprocessing Array"""
def __init__(self, dtype, shape):
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
def save(self, ndarray):
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
np.copyto(dst_np, ndarray)
def get(self):
return np.frombuffer(self.arr.get_obj(),
dtype=self.dtype).reshape(self.shape)
class ShmemVectorEnv(SubprocVectorEnv):
"""Optimized version of SubprocVectorEnv that uses shared variables to
communicate observations. SubprocVectorEnv has exactly the same API as
SubprocVectorEnv.
.. seealso::
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
detailed explanation.
ShmemVectorEnv Class was inspired by openai baseline's implementation.
Please refer to 'https://github.com/openai/baselines/blob/master/baselines/
common/vec_env/shmem_vec_env.py' for more info if you are interested.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
BaseVectorEnv.__init__(self, env_fns)
# Mind that SubprocVectorEnv is not initialised.
self.closed = False
dummy = env_fns[0]()
obs_space = dummy.observation_space
dummy.close()
del dummy
self.obs_bufs = [ShmemVectorEnv._setup_buf(obs_space)
for _ in range(self.env_num)]
self.parent_remote, self.child_remote = \
zip(*[Pipe() for _ in range(self.env_num)])
self.processes = [
Process(target=_shmem_worker, args=(
parent, child, CloudpickleWrapper(env_fn),
obs_buf), daemon=True)
for (parent, child, env_fn, obs_buf) in zip(
self.parent_remote, self.child_remote, env_fns, self.obs_bufs)
]
for p in self.processes:
p.start()
for c in self.child_remote:
c.close()
def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
assert len(action) == len(id)
for i, j in enumerate(id):
self.parent_remote[j].send(['step', action[i]])
result = []
for i in id:
obs, rew, done, info = self.parent_remote[i].recv()
obs = self._decode_obs(obs, i)
result.append((obs, rew, done, info))
obs, rew, done, info = map(np.stack, zip(*result))
return obs, rew, done, info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
for i in id:
self.parent_remote[i].send(['reset', None])
obs = np.stack(
[self._decode_obs(self.parent_remote[i].recv(), i) for i in id])
return obs
@staticmethod
def _setup_buf(space):
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
buffer = {k: ShmemVectorEnv._setup_buf(v)
for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
buffer = tuple([ShmemVectorEnv._setup_buf(t)
for t in space.spaces])
else:
buffer = ShArray(space.dtype, space.shape)
return buffer
def _decode_obs(self, isNone, index):
def decode_obs(buffer):
if isinstance(buffer, ShArray):
return buffer.get()
elif isinstance(buffer, tuple):
return tuple([decode_obs(b) for b in buffer])
elif isinstance(buffer, dict):
return {k: decode_obs(v) for k, v in buffer.items()}
else:
raise NotImplementedError
return decode_obs(self.obs_bufs[index])

View File

@ -7,7 +7,7 @@ from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper
def worker(parent, p, env_fn_wrapper):
def _worker(parent, p, env_fn_wrapper):
parent.close()
env = env_fn_wrapper.data()
try:
@ -49,7 +49,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self.parent_remote, self.child_remote = \
zip(*[Pipe() for _ in range(self.env_num)])
self.processes = [
Process(target=worker, args=(
Process(target=_worker, args=(
parent, child, CloudpickleWrapper(env_fn)), daemon=True)
for (parent, child, env_fn) in zip(
self.parent_remote, self.child_remote, env_fns)