Asynchronous sampling vector environment (#134)

Fix #103

Co-authored-by: youkaichao <youkaichao@126.com>
Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
Alexis DUBURCQ 2020-07-26 12:01:21 +02:00 committed by GitHub
parent 30368c29a6
commit e024afab8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 543 additions and 266 deletions

View File

@ -1,5 +1,6 @@
import gym
import time
import random
import numpy as np
from gym.spaces import Discrete, MultiDiscrete, Box
@ -9,9 +10,10 @@ class MyTestEnv(gym.Env):
"""
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
multidiscrete_action=False):
multidiscrete_action=False, random_sleep=False):
self.size = size
self.sleep = sleep
self.random_sleep = random_sleep
self.dict_state = dict_state
self.ma_rew = ma_rew
self._md_action = multidiscrete_action
@ -48,7 +50,9 @@ class MyTestEnv(gym.Env):
if self.done:
raise ValueError('step after done !!!')
if self.sleep > 0:
time.sleep(self.sleep)
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)
if self.index == self.size:
self.done = True
return self._get_dict_state(), self._get_reward(), self.done, {}

View File

@ -2,7 +2,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import BasePolicy
from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.env import VectorEnv, SubprocVectorEnv, AsyncVectorEnv
from tianshou.data import Collector, Batch, ReplayBuffer
if __name__ == '__main__':
@ -103,6 +103,51 @@ def test_collector():
c2.collect(n_episode=[1, 1, 1, 1], random=True)
def test_collector_with_async():
    """Collect ten episodes through an ``AsyncVectorEnv`` and check that the
    buffer holds whole episodes, each produced by a single environment.

    Every environment ``k`` has episode length ``env_lens[k]``; the test
    walks the buffer episode by episode, checks the owning ``env_id`` never
    changes inside an episode, and rebuilds the expected observations
    (``0, 1, ..., length - 1`` per episode) to compare with the stored ones.
    """
    env_lens = [2, 3, 4, 5]
    writer = SummaryWriter('log/async_collector')
    logger = Logger(writer)
    # bind the length as a default argument so each lambda keeps its own
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
               for i in env_lens]
    venv = AsyncVectorEnv(env_fns)
    policy = MyPolicy()
    replay = ReplayBuffer(size=1000, ignore_obs_next=False)
    c1 = Collector(policy, venv, replay, logger.preprocess_fn)
    c1.collect(n_episode=10)

    # the data in the buffer must be chronological: full episodes, each one
    # returned by the same environment
    env_id = c1.buffer.info['env_id']
    size = len(c1.buffer)
    obs = c1.buffer.obs[:size]
    done = c1.buffer.done[:size]
    print(env_id[:size])
    print(obs)

    obs_ground_truth = []
    start = 0  # index of the first transition of the current episode
    while start < size:
        owner = env_id[start]
        if done[start]:
            # an episode made of a single transition
            assert env_lens[owner] == 1
            start += 1
            continue
        # scan forward to the transition that finishes the episode
        last = start + 1
        assert env_id[last] == owner
        while not done[last]:
            last += 1
            assert env_id[last] == owner
        following = last + 1  # first transition of the next episode
        assert following - start == env_lens[owner]
        obs_ground_truth.extend(range(following - start))
        start = following
    assert np.allclose(obs, obs_ground_truth)
def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
@ -181,3 +226,4 @@ if __name__ == '__main__':
test_collector()
test_collector_with_dict_state()
test_collector_with_ma()
test_collector_with_async()

View File

@ -1,7 +1,9 @@
import time
import numpy as np
from gym.spaces.discrete import Discrete
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
from tianshou.data import Batch
from tianshou.env import VectorEnv, SubprocVectorEnv, \
RayVectorEnv, AsyncVectorEnv
if __name__ == '__main__':
from env import MyTestEnv
@ -9,6 +11,43 @@ else: # pytest
from test.base.env import MyTestEnv
def test_async_env(num=8, sleep=0.1):
    """Step an ``AsyncVectorEnv`` through a fixed action list and check that
    it is faster than the synchronous estimate.

    :param num: number of environments to create.
    :param sleep: upper bound of the random per-step sleep of each
        environment, in seconds.

    :return: a tuple ``(spent_time, data)`` with the wall-clock time of the
        stepping loop and the concatenated ``Batch`` of all step results.
    """
    # simplify the test case, just keep stepping: the episodes are far
    # longer than the number of steps taken here
    size = 10000
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    v = AsyncVectorEnv(env_fns, wait_num=num // 2)
    v.seed()
    v.reset()
    # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
    # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
    # expectation of v is n / (n + 1)
    # for a synchronous environment, the following actions should take
    # about 7 * sleep * num / (num + 1) seconds
    # for AsyncVectorEnv, the analysis is complicated, but the time cost
    # should be smaller
    action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
    current_index_start = 0
    action = action_list[:num]
    env_ids = list(range(num))
    o = []
    spent_time = time.time()
    while current_index_start < len(action_list):
        obs, rew, done, info = v.step(action=action, id=env_ids)
        b = Batch({'obs': obs, 'rew': rew, 'done': done, 'info': info})
        env_ids = b.info.env_id
        o.append(b)
        current_index_start += len(action)
        # near the end of action_list fewer actions may remain than there
        # are ready environments
        action = action_list[
            current_index_start: current_index_start + len(obs)]
        # keep only as many ready environments as there are actions left,
        # otherwise step() rejects the call because the lengths differ
        env_ids = env_ids[:len(action)]
    spent_time = time.time() - spent_time
    data = Batch.cat(o)
    # assure 1/7 improvement
    assert spent_time < 6.0 * sleep * num / (num + 1)
    return spent_time, data
def test_vecenv(size=10, num=8, sleep=0.001):
verbose = __name__ == '__main__'
env_fns = [
@ -60,3 +99,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):
if __name__ == '__main__':
test_vecenv()
test_async_env()

View File

@ -5,10 +5,11 @@ import warnings
import numpy as np
from typing import Any, Dict, List, Union, Optional, Callable
from tianshou.env import BaseVectorEnv, VectorEnv
from tianshou.env import BaseVectorEnv, VectorEnv, AsyncVectorEnv
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
from tianshou.data.batch import _create_value
class Collector(object):
@ -96,6 +97,13 @@ class Collector(object):
env = VectorEnv([lambda: env])
self.env = env
self.env_num = len(env)
# environments that are available in step()
# this means all environments in synchronous simulation
# but only a subset of environments in asynchronous simulation
self._ready_env_ids = np.arange(self.env_num)
# self.is_async is a flag that indicates whether this collector works
# with asynchronous simulation
self.is_async = isinstance(env, AsyncVectorEnv)
# need cache buffers before storing in the main buffer
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
@ -105,6 +113,9 @@ class Collector(object):
self.process_fn = policy.process_fn
self._action_noise = action_noise
self._rew_metric = reward_metric or Collector._default_rew_metric
# avoid creating attribute outside __init__
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
obs_next={}, policy={})
self.reset()
@staticmethod
@ -139,6 +150,7 @@ class Collector(object):
"""Reset all of the environment(s)' states and reset all of the cache
buffers (if need).
"""
self._ready_env_ids = np.arange(self.env_num)
obs = self.env.reset()
if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs).get('obs', obs)
@ -159,7 +171,7 @@ class Collector(object):
self.env.close()
def _reset_state(self, id: Union[int, List[int]]) -> None:
"""Reset self.data.state[id]."""
"""Reset the hidden state: self.data.state[id]."""
state = self.data.state # it is a reference
if isinstance(state, torch.Tensor):
state[id].zero_()
@ -207,6 +219,7 @@ class Collector(object):
# episode of each environment
episode_count = np.zeros(self.env_num)
reward_total = 0.0
whole_data = Batch()
while True:
if step_count >= 100000 and episode_count.sum() == 0:
warnings.warn(
@ -214,6 +227,15 @@ class Collector(object):
'You should add a time limitation to your environment!',
Warning)
if self.is_async:
# self.data are the data for all environments
# in async simulation, only a subset of data are disposed
# so we store the whole data in ``whole_data``, let self.data
# to be all the data available in ready environments, and
# finally set these back into all the data
whole_data = self.data
self.data = self.data[self._ready_env_ids]
# restore the state and the input data
last_state = self.data.state
if last_state.is_empty():
@ -222,8 +244,16 @@ class Collector(object):
# calculate the next action
if random:
if self.is_async:
# TODO self.env.action_space will invoke remote call for
# all environments, which may hang in async simulation.
# This can be avoided by using a random policy, but not
# in the collector level. Leave it as a future work.
raise RuntimeError("cannot use random "
"sampling in async simulation!")
spaces = self.env.action_space
result = Batch(
act=[a.sample() for a in self.env.action_space])
act=[spaces[i].sample() for i in self._ready_env_ids])
else:
with torch.no_grad():
result = self.policy(self.data, last_state)
@ -243,8 +273,18 @@ class Collector(object):
self.data.act += self._action_noise(self.data.act.shape)
# step in env
obs_next, rew, done, info = self.env.step(self.data.act)
if not self.is_async:
obs_next, rew, done, info = self.env.step(self.data.act)
else:
# store computed actions, states, etc
_batch_set_item(whole_data, self._ready_env_ids,
self.data, self.env_num)
# fetch finished data
obs_next, rew, done, info = self.env.step(
action=self.data.act, id=self._ready_env_ids)
self._ready_env_ids = np.array([i['env_id'] for i in info])
# get the stepped data
self.data = whole_data[self._ready_env_ids]
# move data to self.data
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
@ -256,9 +296,11 @@ class Collector(object):
if self.preprocess_fn:
result = self.preprocess_fn(**self.data)
self.data.update(result)
for i in range(self.env_num):
self._cached_buf[i].add(**self.data[i])
if self.data.done[i]:
for j, i in enumerate(self._ready_env_ids):
# j is the index in current ready_env_ids
# i is the index in all environments
self._cached_buf[i].add(**self.data[j])
if self.data.done[j]:
if n_step or np.isscalar(n_episode) or \
episode_count[i] < n_episode[i]:
episode_count[i] += 1
@ -267,17 +309,24 @@ class Collector(object):
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
self._cached_buf[i].reset()
self._reset_state(i)
self._reset_state(j)
obs_next = self.data.obs_next
if sum(self.data.done):
env_ind = np.where(self.data.done)[0]
obs_reset = self.env.reset(env_ind)
env_ind_local = np.where(self.data.done)[0]
env_ind_global = self._ready_env_ids[env_ind_local]
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_next[env_ind] = self.preprocess_fn(
obs_next[env_ind_local] = self.preprocess_fn(
obs=obs_reset).get('obs', obs_reset)
else:
obs_next[env_ind] = obs_reset
obs_next[env_ind_local] = obs_reset
self.data.obs = obs_next
if self.is_async:
# set data back
_batch_set_item(whole_data, self._ready_env_ids,
self.data, self.env_num)
# let self.data be the data in all environments again
self.data = whole_data
if n_step:
if step_count >= n_step:
break
@ -320,3 +369,24 @@ class Collector(object):
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data
def _batch_set_item(source: Batch, indices: np.ndarray,
                    target: Batch, size: int):
    """Copy every non-reserved entry of ``target`` into ``source[indices]``.

    Despite the names, ``source`` is the batch that gets modified and
    ``target`` supplies the values. An entry counts as "reserved" here when
    it is an empty :class:`Batch`.

    :param source: the batch written to, in place.
    :param indices: positions inside ``source`` that receive the values.
    :param target: the batch whose entries are copied.
    :param size: length passed to ``_create_value`` when ``source`` has no
        storage for a key yet (presumably the total number of environments;
        confirm against the caller).
    """
    # for any key k, there are three cases
    # 1. source[k] is non-reserved, but target[k] does not exist or is
    #    reserved
    # 2. source[k] does not exist or is reserved, but target[k] is
    #    non-reserved
    # 3. both source[k] and target[k] are non-reserved
    for key, value in target.items():
        if isinstance(value, Batch) and value.is_empty():
            # case 1: target[key] is reserved, there is nothing to copy
            continue
        existing = source.get(key, Batch())
        if isinstance(existing, Batch) and existing.is_empty():
            # case 2: create the storage in source first
            # use __dict__ to avoid many type checks
            source.__dict__[key] = _create_value(value[0], size)
        # cases 2 and 3: write the values at the requested positions
        source.__dict__[key][indices] = value

View File

@ -1,11 +1,15 @@
from tianshou.env.basevecenv import BaseVectorEnv
from tianshou.env.vecenv import VectorEnv, SubprocVectorEnv, RayVectorEnv
from tianshou.env.vecenv.base import BaseVectorEnv
from tianshou.env.vecenv.dummy import VectorEnv
from tianshou.env.vecenv.subproc import SubprocVectorEnv
from tianshou.env.vecenv.asyncenv import AsyncVectorEnv
from tianshou.env.vecenv.rayenv import RayVectorEnv
from tianshou.env.maenv import MultiAgentEnv
__all__ = [
'BaseVectorEnv',
'VectorEnv',
'SubprocVectorEnv',
'AsyncVectorEnv',
'RayVectorEnv',
'MultiAgentEnv',
]

247
tianshou/env/vecenv.py vendored
View File

@ -1,247 +0,0 @@
import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any
try:
import ray
except ImportError:
pass
from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper
class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    All environments live in the current process and are stepped one after
    another.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Create one environment per factory in ``env_fns``."""
        super().__init__(env_fns)
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        """Return attribute ``key`` of every environment as a list, with
        ``None`` for environments that do not have it."""
        return [getattr(env, key, None) for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked initial observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the environments selected by ``id`` (all if ``None``).

        ``action[k]`` is applied to environment ``id[k]``, so ``action`` and
        ``id`` must have the same length.

        :return: stacked ``(obs, rew, done, info)`` of the stepped
            environments, in the order given by ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # action is indexed by position, the environment by its id; indexing
        # action by the environment id breaks as soon as id is a subset
        result = [self.envs[j].step(action[i]) for i, j in enumerate(id)]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments.

        A scalar ``seed`` is expanded to ``seed, seed + 1, ...``; ``None`` is
        passed to every environment unchanged. Environments without a
        ``seed`` method contribute nothing to the returned list.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        """Render every environment that has a ``render`` method and return
        the collected results."""
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        """Close every environment and return the list of results."""
        return [e.close() for e in self.envs]
def worker(parent, p, env_fn_wrapper):
    """Serve environment commands received over a pipe until told to close.

    :param parent: the parent's end of the pipe; it is closed right away
        because this side does not use it.
    :param p: the connection commands are read from and replies written to.
    :param env_fn_wrapper: object whose ``data()`` call builds the
        environment.

    Each message is a ``(cmd, data)`` pair. ``'step'``, ``'reset'``,
    ``'render'``, ``'seed'`` and ``'getattr'`` send back one reply each
    (``None`` when the environment lacks the requested attribute);
    ``'close'`` replies, closes the connection and ends the loop. Any other
    command closes the connection and raises ``NotImplementedError``.
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'close':
                p.send(env.close())
                p.close()
                return
            if cmd == 'step':
                reply = env.step(data)
            elif cmd == 'reset':
                reply = env.reset()
            elif cmd == 'render':
                reply = env.render(**data) if hasattr(env, 'render') else None
            elif cmd == 'seed':
                reply = env.seed(data) if hasattr(env, 'seed') else None
            elif cmd == 'getattr':
                reply = getattr(env, data) if hasattr(env, data) else None
            else:
                p.close()
                raise NotImplementedError
            p.send(reply)
    except KeyboardInterrupt:
        # interrupted from the terminal: just drop the connection
        p.close()
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper that runs every environment in its own
    subprocess and exchanges commands with it through a pipe.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Start one daemon worker process per factory in ``env_fns``."""
        super().__init__(env_fns)
        self.closed = False
        pipes = [Pipe() for _ in range(self.env_num)]
        self.parent_remote, self.child_remote = zip(*pipes)
        procs = []
        for parent_end, child_end, make_env in zip(
                self.parent_remote, self.child_remote, env_fns):
            procs.append(Process(
                target=worker,
                args=(parent_end, child_end, CloudpickleWrapper(make_env)),
                daemon=True))
        self.processes = procs
        for proc in self.processes:
            proc.start()
        # the child ends are used by the workers only
        for child_end in self.child_remote:
            child_end.close()

    def __getattr__(self, key):
        """Ask every worker for attribute ``key`` and return the replies."""
        for remote in self.parent_remote:
            remote.send(['getattr', key])
        replies = []
        for remote in self.parent_remote:
            replies.append(remote.recv())
        return replies

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Send ``action[k]`` to environment ``id[k]`` (all environments if
        ``id`` is ``None``), wait for every reply and return the stacked
        ``(obs, rew, done, info)``."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # dispatch everything first so the workers run concurrently
        for position, env_index in enumerate(id):
            self.parent_remote[env_index].send(['step', action[position]])
        transitions = [self.parent_remote[env_index].recv()
                       for env_index in id]
        obs, rew, done, info = map(np.stack, zip(*transitions))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for env_index in id:
            self.parent_remote[env_index].send(['reset', None])
        observations = [self.parent_remote[env_index].recv()
                        for env_index in id]
        return np.stack(observations)

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments; a scalar is expanded to consecutive seeds
        and ``None`` is forwarded to every worker."""
        if np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        elif seed is None:
            seed = [None] * self.env_num
        for remote, value in zip(self.parent_remote, seed):
            remote.send(['seed', value])
        return [remote.recv() for remote in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Render every environment and return the replies."""
        for remote in self.parent_remote:
            remote.send(['render', kwargs])
        return [remote.recv() for remote in self.parent_remote]

    def close(self) -> List[Any]:
        """Close every worker, join the processes and return the replies;
        returns an empty list if already closed."""
        if self.closed:
            return []
        for remote in self.parent_remote:
            remote.send(['close', None])
        replies = [remote.recv() for remote in self.parent_remote]
        self.closed = True
        for proc in self.processes:
            proc.join()
        return replies
class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. However, according to our
    test, it is about two times slower than
    :class:`~tianshou.env.SubprocVectorEnv`.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Initialize ray if necessary and wrap each environment built from
        ``env_fns`` in a remote ``gym.Wrapper``."""
        super().__init__(env_fns)
        try:
            if not ray.is_initialized():
                ray.init()
        except NameError:
            # the guarded module-level import of ray failed
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip3 install ray')
        remote_envs = []
        for make_env in env_fns:
            remote_envs.append(
                ray.remote(gym.Wrapper).options(num_cpus=0).remote(
                    make_env()))
        self.envs = remote_envs

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every remote environment."""
        pending = [e.__getattr__.remote(key) for e in self.envs]
        return ray.get(pending)

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Apply ``action[k]`` to environment ``id[k]`` (all environments if
        ``id`` is ``None``) and return the stacked
        ``(obs, rew, done, info)``."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        pending = [self.envs[env_index].step.remote(action[position])
                   for position, env_index in enumerate(id)]
        transitions = ray.get(pending)
        obs, rew, done, info = map(np.stack, zip(*transitions))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        pending = [self.envs[env_index].reset.remote() for env_index in id]
        return np.stack(ray.get(pending))

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments; returns ``[]`` when the first environment
        has no ``seed`` attribute."""
        if not hasattr(self.envs[0], 'seed'):
            return []
        if np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        elif seed is None:
            seed = [None] * self.env_num
        pending = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
        return ray.get(pending)

    def render(self, **kwargs) -> List[Any]:
        """Render every environment; returns a list of ``None`` when the
        first environment has no ``render`` attribute."""
        if not hasattr(self.envs[0], 'render'):
            return [None for e in self.envs]
        pending = [e.render.remote(**kwargs) for e in self.envs]
        return ray.get(pending)

    def close(self) -> List[Any]:
        """Close every remote environment and return the results."""
        pending = [e.close.remote() for e in self.envs]
        return ray.get(pending)

0
tianshou/env/vecenv/__init__.py vendored Normal file
View File

104
tianshou/env/vecenv/asyncenv.py vendored Normal file
View File

@ -0,0 +1,104 @@
import gym
import numpy as np
from multiprocessing import connection
from typing import List, Tuple, Union, Optional, Callable, Any
from tianshou.env import SubprocVectorEnv
class AsyncVectorEnv(SubprocVectorEnv):
    """Vectorized asynchronous environment wrapper based on subprocess.

    :param wait_num: used in asynchronous simulation if the time cost of
        ``env.step`` varies with time and synchronously waiting for all
        environments to finish a step is time-wasting. In that case, we can
        return when ``wait_num`` environments finish a step and keep on
        simulation in these environments. If ``None``, asynchronous simulation
        is disabled; else, ``1 <= wait_num <= env_num``.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None) -> None:
        super().__init__(env_fns)
        self.wait_num = wait_num or len(env_fns)
        assert 1 <= self.wait_num <= len(env_fns), \
            f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
        # connections of the environments with an unfinished step() call;
        # waiting_conn[k] belongs to environment waiting_id[k]
        self.waiting_conn = []
        # environments in self.ready_id are actually ready
        # but environments in self.waiting_id were just waiting when checked,
        # and they may be ready now, but this is not known until we check it
        # in the step() function
        self.waiting_id = []
        # all environments are ready in the beginning
        self.ready_id = list(range(self.env_num))

    def _assert_and_transform_id(self,
                                 id: Optional[Union[int, List[int]]] = None
                                 ) -> List[int]:
        """Normalize ``id`` to a sequence of environment ids (all
        environments if ``None``) and check that each of them is ready,
        i.e. has no unfinished step() call."""
        if id is None:
            id = list(range(self.env_num))
        elif np.isscalar(id):
            id = [id]
        for i in id:
            assert i not in self.waiting_id, \
                f'Cannot reset environment {i} which is stepping now!'
            assert i in self.ready_id, \
                f'Can only reset ready environments {self.ready_id}.'
        return id

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the ready environments selected by ``id`` (all if ``None``)
        and return their stacked observations."""
        id = self._assert_and_transform_id(id)
        return super().reset(id)

    def render(self, **kwargs) -> List[Any]:
        """Render all environments.

        :raises RuntimeError: if some environments are still stepping.
        """
        if len(self.waiting_id) > 0:
            raise RuntimeError(
                f"Environments {self.waiting_id} are still "
                f"stepping, cannot render them now.")
        return super().render(**kwargs)

    def close(self) -> List[Any]:
        """Finish every unfinished step() call, then close all workers.

        Returns an empty list if the environments are already closed.
        """
        if self.closed:
            return []
        # one step(None) returns after wait_num results, so keep fetching
        # until nothing is pending; otherwise a leftover step reply would be
        # read as the reply to the close command
        while self.waiting_conn:
            self.step(None)
        return super().close()

    def step(self,
             action: Optional[np.ndarray],
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Provide the given action to the environments. The action sequence
        should correspond to the ``id`` argument, and the ``id`` argument
        should be a subset of the ``env_id`` in the last returned ``info``
        (initially they are env_ids of all the environments). If action is
        ``None``, fetch unfinished step() calls instead.

        :return: stacked ``(obs, rew, done, info)`` of the environments that
            finished a step; ``info[k]["env_id"]`` tells which environment
            row ``k`` belongs to. Four empty arrays are returned when no
            step() call is pending.
        """
        if action is not None:
            id = self._assert_and_transform_id(id)
            assert len(action) == len(id)
            for act, env_id in zip(action, id):
                self.parent_remote[env_id].send(['step', act])
                self.waiting_conn.append(self.parent_remote[env_id])
                self.waiting_id.append(env_id)
            self.ready_id = [x for x in self.ready_id if x not in id]
        result = []
        while len(self.waiting_conn) > 0 and len(result) < self.wait_num:
            # blocks until at least one pending connection has a reply
            ready_conns = connection.wait(self.waiting_conn)
            for conn in ready_conns:
                waiting_index = self.waiting_conn.index(conn)
                self.waiting_conn.pop(waiting_index)
                env_id = self.waiting_id.pop(waiting_index)
                obs, rew, done, info = conn.recv()
                info["env_id"] = env_id
                result.append((obs, rew, done, info))
                self.ready_id.append(env_id)
        if not result:
            # nothing was pending: unpacking zip(*[]) would raise ValueError
            return np.array([]), np.array([]), np.array([]), np.array([])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

65
tianshou/env/vecenv/dummy.py vendored Normal file
View File

@ -0,0 +1,65 @@
import gym
import numpy as np
from typing import List, Tuple, Union, Optional, Callable, Any
from tianshou.env import BaseVectorEnv
class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    All environments live in the current process and are stepped one after
    another.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Create one environment per factory in ``env_fns``."""
        super().__init__(env_fns)
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        """Return attribute ``key`` of every environment as a list, with
        ``None`` for environments that do not have it."""
        return [getattr(env, key, None) for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked initial observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the environments selected by ``id`` (all if ``None``).

        ``action[k]`` is applied to environment ``id[k]``, so ``action`` and
        ``id`` must have the same length.

        :return: stacked ``(obs, rew, done, info)`` of the stepped
            environments, in the order given by ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # action is indexed by position, the environment by its id; indexing
        # action by the environment id breaks as soon as id is a subset
        result = [self.envs[j].step(action[i]) for i, j in enumerate(id)]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments.

        A scalar ``seed`` is expanded to ``seed, seed + 1, ...``; ``None`` is
        passed to every environment unchanged. Environments without a
        ``seed`` method contribute nothing to the returned list.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        """Render every environment that has a ``render`` method and return
        the collected results."""
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        """Close every environment and return the list of results."""
        return [e.close() for e in self.envs]

76
tianshou/env/vecenv/rayenv.py vendored Normal file
View File

@ -0,0 +1,76 @@
import gym
import numpy as np
from typing import List, Tuple, Union, Optional, Callable, Any
try:
import ray
except ImportError:
pass
from tianshou.env import BaseVectorEnv
class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. This is a choice to run
    distributed environments in a cluster.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Initialize ray if necessary and wrap each environment built from
        ``env_fns`` in a remote ``gym.Wrapper``."""
        super().__init__(env_fns)
        try:
            if not ray.is_initialized():
                ray.init()
        except NameError:
            # the guarded module-level import of ray failed
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip install ray')
        remote_envs = []
        for make_env in env_fns:
            remote_envs.append(
                ray.remote(gym.Wrapper).options(num_cpus=0).remote(
                    make_env()))
        self.envs = remote_envs

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every remote environment."""
        pending = [e.__getattr__.remote(key) for e in self.envs]
        return ray.get(pending)

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Apply ``action[k]`` to environment ``id[k]`` (all environments if
        ``id`` is ``None``) and return the stacked
        ``(obs, rew, done, info)``."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        pending = [self.envs[env_index].step.remote(action[position])
                   for position, env_index in enumerate(id)]
        transitions = ray.get(pending)
        obs, rew, done, info = map(np.stack, zip(*transitions))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        pending = [self.envs[env_index].reset.remote() for env_index in id]
        return np.stack(ray.get(pending))

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments; returns ``[]`` when the first environment
        has no ``seed`` attribute."""
        if not hasattr(self.envs[0], 'seed'):
            return []
        if np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        elif seed is None:
            seed = [None] * self.env_num
        pending = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
        return ray.get(pending)

    def render(self, **kwargs) -> List[Any]:
        """Render every environment; returns a list of ``None`` when the
        first environment has no ``render`` attribute."""
        if not hasattr(self.envs[0], 'render'):
            return [None for e in self.envs]
        pending = [e.render.remote(**kwargs) for e in self.envs]
        return ray.get(pending)

    def close(self) -> List[Any]:
        """Close every remote environment and return the results."""
        pending = [e.close.remote() for e in self.envs]
        return ray.get(pending)

115
tianshou/env/vecenv/subproc.py vendored Normal file
View File

@ -0,0 +1,115 @@
import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any
from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper
def worker(parent, p, env_fn_wrapper):
    """Serve environment commands received over a pipe until told to close.

    :param parent: the parent's end of the pipe; it is closed right away
        because this side does not use it.
    :param p: the connection commands are read from and replies written to.
    :param env_fn_wrapper: object whose ``data()`` call builds the
        environment.

    Each message is a ``(cmd, data)`` pair. ``'step'``, ``'reset'``,
    ``'render'``, ``'seed'`` and ``'getattr'`` send back one reply each
    (``None`` when the environment lacks the requested attribute);
    ``'close'`` replies, closes the connection and ends the loop. Any other
    command closes the connection and raises ``NotImplementedError``.
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'close':
                p.send(env.close())
                p.close()
                return
            if cmd == 'step':
                reply = env.step(data)
            elif cmd == 'reset':
                reply = env.reset()
            elif cmd == 'render':
                reply = env.render(**data) if hasattr(env, 'render') else None
            elif cmd == 'seed':
                reply = env.seed(data) if hasattr(env, 'seed') else None
            elif cmd == 'getattr':
                reply = getattr(env, data) if hasattr(env, data) else None
            else:
                p.close()
                raise NotImplementedError
            p.send(reply)
    except KeyboardInterrupt:
        # interrupted from the terminal: just drop the connection
        p.close()
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper that runs every environment in its own
    subprocess and exchanges commands with it through a pipe.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Start one daemon worker process per factory in ``env_fns``."""
        super().__init__(env_fns)
        self.closed = False
        pipes = [Pipe() for _ in range(self.env_num)]
        self.parent_remote, self.child_remote = zip(*pipes)
        procs = []
        for parent_end, child_end, make_env in zip(
                self.parent_remote, self.child_remote, env_fns):
            procs.append(Process(
                target=worker,
                args=(parent_end, child_end, CloudpickleWrapper(make_env)),
                daemon=True))
        self.processes = procs
        for proc in self.processes:
            proc.start()
        # the child ends are used by the workers only
        for child_end in self.child_remote:
            child_end.close()

    def __getattr__(self, key):
        """Ask every worker for attribute ``key`` and return the replies."""
        for remote in self.parent_remote:
            remote.send(['getattr', key])
        replies = []
        for remote in self.parent_remote:
            replies.append(remote.recv())
        return replies

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Send ``action[k]`` to environment ``id[k]`` (all environments if
        ``id`` is ``None``), wait for every reply and return the stacked
        ``(obs, rew, done, info)``."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # dispatch everything first so the workers run concurrently
        for position, env_index in enumerate(id):
            self.parent_remote[env_index].send(['step', action[position]])
        transitions = [self.parent_remote[env_index].recv()
                       for env_index in id]
        obs, rew, done, info = map(np.stack, zip(*transitions))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for env_index in id:
            self.parent_remote[env_index].send(['reset', None])
        observations = [self.parent_remote[env_index].recv()
                        for env_index in id]
        return np.stack(observations)

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments; a scalar is expanded to consecutive seeds
        and ``None`` is forwarded to every worker."""
        if np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        elif seed is None:
            seed = [None] * self.env_num
        for remote, value in zip(self.parent_remote, seed):
            remote.send(['seed', value])
        return [remote.recv() for remote in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Render every environment and return the replies."""
        for remote in self.parent_remote:
            remote.send(['render', kwargs])
        return [remote.recv() for remote in self.parent_remote]

    def close(self) -> List[Any]:
        """Close every worker, join the processes and return the replies;
        returns an empty list if already closed."""
        if self.closed:
            return []
        for remote in self.parent_remote:
            remote.send(['close', None])
        replies = [remote.recv() for remote in self.parent_remote]
        self.closed = True
        for proc in self.processes:
            proc.join()
        return replies