Asynchronous sampling vector environment (#134)
Fix #103 Co-authored-by: youkaichao <youkaichao@126.com> Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
parent
30368c29a6
commit
e024afab8c
@ -1,5 +1,6 @@
|
||||
import gym
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, MultiDiscrete, Box
|
||||
|
||||
@ -9,9 +10,10 @@ class MyTestEnv(gym.Env):
|
||||
"""
|
||||
|
||||
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
|
||||
multidiscrete_action=False):
|
||||
multidiscrete_action=False, random_sleep=False):
|
||||
self.size = size
|
||||
self.sleep = sleep
|
||||
self.random_sleep = random_sleep
|
||||
self.dict_state = dict_state
|
||||
self.ma_rew = ma_rew
|
||||
self._md_action = multidiscrete_action
|
||||
@ -48,7 +50,9 @@ class MyTestEnv(gym.Env):
|
||||
if self.done:
|
||||
raise ValueError('step after done !!!')
|
||||
if self.sleep > 0:
|
||||
time.sleep(self.sleep)
|
||||
sleep_time = random.random() if self.random_sleep else 1
|
||||
sleep_time *= self.sleep
|
||||
time.sleep(sleep_time)
|
||||
if self.index == self.size:
|
||||
self.done = True
|
||||
return self._get_dict_state(), self._get_reward(), self.done, {}
|
||||
|
@ -2,7 +2,7 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv, AsyncVectorEnv
|
||||
from tianshou.data import Collector, Batch, ReplayBuffer
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -103,6 +103,51 @@ def test_collector():
|
||||
c2.collect(n_episode=[1, 1, 1, 1], random=True)
|
||||
|
||||
|
||||
def test_collector_with_async():
    """Collect 10 episodes through an AsyncVectorEnv and verify the buffer.

    The buffer must contain whole episodes only, every transition of an
    episode must come from the same environment, and the episode length
    must match the size of the environment that produced it.
    """
    env_lens = [2, 3, 4, 5]
    logger = Logger(SummaryWriter('log/async_collector'))
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
               for i in env_lens]

    venv = AsyncVectorEnv(env_fns)
    policy = MyPolicy()
    replay = ReplayBuffer(size=1000, ignore_obs_next=False)
    c1 = Collector(policy, venv, replay, logger.preprocess_fn)
    c1.collect(n_episode=10)

    # the stored data must be chronological: full episodes, each one
    # produced by a single environment
    env_id = c1.buffer.info['env_id']
    size = len(c1.buffer)
    obs = c1.buffer.obs[:size]
    done = c1.buffer.done[:size]
    print(env_id[:size])
    print(obs)

    expected_obs = []
    start = 0  # index of the first transition of the current episode
    while start < size:
        if done[start]:
            # a one-transition episode is only valid for a size-1 env
            assert env_lens[env_id[start]] == 1
            start += 1
            continue
        end = start + 1
        while True:
            # within one episode the environment id never changes
            assert env_id[end] == env_id[start]
            if done[end]:
                break
            end += 1
        end += 1  # end is now the start of the next episode
        episode_len = end - start
        assert episode_len == env_lens[env_id[start]]
        expected_obs += list(range(episode_len))
        start = end
    assert np.allclose(obs, expected_obs)
|
||||
|
||||
|
||||
def test_collector_with_dict_state():
|
||||
env = MyTestEnv(size=5, sleep=0, dict_state=True)
|
||||
policy = MyPolicy(dict_state=True)
|
||||
@ -181,3 +226,4 @@ if __name__ == '__main__':
|
||||
test_collector()
|
||||
test_collector_with_dict_state()
|
||||
test_collector_with_ma()
|
||||
test_collector_with_async()
|
||||
|
@ -1,7 +1,9 @@
|
||||
import time
|
||||
import numpy as np
|
||||
from gym.spaces.discrete import Discrete
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
|
||||
from tianshou.data import Batch
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv, \
|
||||
RayVectorEnv, AsyncVectorEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
from env import MyTestEnv
|
||||
@ -9,6 +11,43 @@ else: # pytest
|
||||
from test.base.env import MyTestEnv
|
||||
|
||||
|
||||
def test_async_env(num=8, sleep=0.1):
    """Step an AsyncVectorEnv through a fixed action list and time it.

    :param num: number of environments to create.
    :param sleep: upper bound (in seconds) of the random sleep performed
        by each environment step.

    :return: a tuple ``(spent_time, data)`` with the wall-clock time spent
        in the stepping loop and the concatenated step results.
    """
    # simplify the test case, just keep stepping
    size = 10000
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    v = AsyncVectorEnv(env_fns, wait_num=num // 2)
    v.seed()
    v.reset()
    # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
    # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
    # expectation of v is n / (n + 1)
    # for a synchronous environment, the following actions should take
    # about 7 * sleep * num / (num + 1) seconds
    # for AsyncVectorEnv, the analysis is complicated, but the time cost
    # should be smaller
    action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
    current_index_start = 0
    action = action_list[:num]
    env_ids = list(range(num))
    o = []
    spent_time = time.time()
    while current_index_start < len(action_list):
        A, B, C, D = v.step(action=action, id=env_ids)
        b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
        o.append(b)
        current_index_start += len(action)
        action = action_list[current_index_start: current_index_start + len(A)]
        # step() may hand back more than wait_num results, so close to the
        # end of action_list fewer actions than ready environments can be
        # left; only address as many environments as there are actions,
        # otherwise step() rejects the length mismatch
        env_ids = b.info.env_id[:len(action)]
    spent_time = time.time() - spent_time
    data = Batch.cat(o)
    # assure 1/7 improvement
    assert spent_time < 6.0 * sleep * num / (num + 1)
    return spent_time, data
|
||||
|
||||
|
||||
def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
verbose = __name__ == '__main__'
|
||||
env_fns = [
|
||||
@ -60,3 +99,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_vecenv()
|
||||
test_async_env()
|
||||
|
@ -5,10 +5,11 @@ import warnings
|
||||
import numpy as np
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from tianshou.env import BaseVectorEnv, VectorEnv
|
||||
from tianshou.env import BaseVectorEnv, VectorEnv, AsyncVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
|
||||
from tianshou.data.batch import _create_value
|
||||
|
||||
|
||||
class Collector(object):
|
||||
@ -96,6 +97,13 @@ class Collector(object):
|
||||
env = VectorEnv([lambda: env])
|
||||
self.env = env
|
||||
self.env_num = len(env)
|
||||
# environments that are available in step()
|
||||
# this means all environments in synchronous simulation
|
||||
# but only a subset of environments in asynchronous simulation
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
# self.is_async is a flag to indicate whether this collector works
|
||||
# with asynchronous simulation
|
||||
self.is_async = isinstance(env, AsyncVectorEnv)
|
||||
# need cache buffers before storing in the main buffer
|
||||
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
|
||||
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
|
||||
@ -105,6 +113,9 @@ class Collector(object):
|
||||
self.process_fn = policy.process_fn
|
||||
self._action_noise = action_noise
|
||||
self._rew_metric = reward_metric or Collector._default_rew_metric
|
||||
# avoid creating attribute outside __init__
|
||||
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
|
||||
obs_next={}, policy={})
|
||||
self.reset()
|
||||
|
||||
@staticmethod
|
||||
@ -139,6 +150,7 @@ class Collector(object):
|
||||
"""Reset all of the environment(s)' states and reset all of the cache
|
||||
buffers (if need).
|
||||
"""
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
obs = self.env.reset()
|
||||
if self.preprocess_fn:
|
||||
obs = self.preprocess_fn(obs=obs).get('obs', obs)
|
||||
@ -159,7 +171,7 @@ class Collector(object):
|
||||
self.env.close()
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
"""Reset self.data.state[id]."""
|
||||
"""Reset the hidden state: self.data.state[id]."""
|
||||
state = self.data.state # it is a reference
|
||||
if isinstance(state, torch.Tensor):
|
||||
state[id].zero_()
|
||||
@ -207,6 +219,7 @@ class Collector(object):
|
||||
# episode of each environment
|
||||
episode_count = np.zeros(self.env_num)
|
||||
reward_total = 0.0
|
||||
whole_data = Batch()
|
||||
while True:
|
||||
if step_count >= 100000 and episode_count.sum() == 0:
|
||||
warnings.warn(
|
||||
@ -214,6 +227,15 @@ class Collector(object):
|
||||
'You should add a time limitation to your environment!',
|
||||
Warning)
|
||||
|
||||
if self.is_async:
|
||||
# self.data are the data for all environments
|
||||
# in async simulation, only a subset of data are disposed
|
||||
# so we store the whole data in ``whole_data``, let self.data
|
||||
# to be all the data available in ready environments, and
|
||||
# finally set these back into all the data
|
||||
whole_data = self.data
|
||||
self.data = self.data[self._ready_env_ids]
|
||||
|
||||
# restore the state and the input data
|
||||
last_state = self.data.state
|
||||
if last_state.is_empty():
|
||||
@ -222,8 +244,16 @@ class Collector(object):
|
||||
|
||||
# calculate the next action
|
||||
if random:
|
||||
if self.is_async:
|
||||
# TODO self.env.action_space will invoke remote call for
|
||||
# all environments, which may hang in async simulation.
|
||||
# This can be avoided by using a random policy, but not
|
||||
# in the collector level. Leave it as a future work.
|
||||
raise RuntimeError("cannot use random "
|
||||
"sampling in async simulation!")
|
||||
spaces = self.env.action_space
|
||||
result = Batch(
|
||||
act=[a.sample() for a in self.env.action_space])
|
||||
act=[spaces[i].sample() for i in self._ready_env_ids])
|
||||
else:
|
||||
with torch.no_grad():
|
||||
result = self.policy(self.data, last_state)
|
||||
@ -243,8 +273,18 @@ class Collector(object):
|
||||
self.data.act += self._action_noise(self.data.act.shape)
|
||||
|
||||
# step in env
|
||||
obs_next, rew, done, info = self.env.step(self.data.act)
|
||||
|
||||
if not self.is_async:
|
||||
obs_next, rew, done, info = self.env.step(self.data.act)
|
||||
else:
|
||||
# store computed actions, states, etc
|
||||
_batch_set_item(whole_data, self._ready_env_ids,
|
||||
self.data, self.env_num)
|
||||
# fetch finished data
|
||||
obs_next, rew, done, info = self.env.step(
|
||||
action=self.data.act, id=self._ready_env_ids)
|
||||
self._ready_env_ids = np.array([i['env_id'] for i in info])
|
||||
# get the stepped data
|
||||
self.data = whole_data[self._ready_env_ids]
|
||||
# move data to self.data
|
||||
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
|
||||
|
||||
@ -256,9 +296,11 @@ class Collector(object):
|
||||
if self.preprocess_fn:
|
||||
result = self.preprocess_fn(**self.data)
|
||||
self.data.update(result)
|
||||
for i in range(self.env_num):
|
||||
self._cached_buf[i].add(**self.data[i])
|
||||
if self.data.done[i]:
|
||||
for j, i in enumerate(self._ready_env_ids):
|
||||
# j is the index in current ready_env_ids
|
||||
# i is the index in all environments
|
||||
self._cached_buf[i].add(**self.data[j])
|
||||
if self.data.done[j]:
|
||||
if n_step or np.isscalar(n_episode) or \
|
||||
episode_count[i] < n_episode[i]:
|
||||
episode_count[i] += 1
|
||||
@ -267,17 +309,24 @@ class Collector(object):
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
self._cached_buf[i].reset()
|
||||
self._reset_state(i)
|
||||
self._reset_state(j)
|
||||
obs_next = self.data.obs_next
|
||||
if sum(self.data.done):
|
||||
env_ind = np.where(self.data.done)[0]
|
||||
obs_reset = self.env.reset(env_ind)
|
||||
env_ind_local = np.where(self.data.done)[0]
|
||||
env_ind_global = self._ready_env_ids[env_ind_local]
|
||||
obs_reset = self.env.reset(env_ind_global)
|
||||
if self.preprocess_fn:
|
||||
obs_next[env_ind] = self.preprocess_fn(
|
||||
obs_next[env_ind_local] = self.preprocess_fn(
|
||||
obs=obs_reset).get('obs', obs_reset)
|
||||
else:
|
||||
obs_next[env_ind] = obs_reset
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
self.data.obs = obs_next
|
||||
if self.is_async:
|
||||
# set data back
|
||||
_batch_set_item(whole_data, self._ready_env_ids,
|
||||
self.data, self.env_num)
|
||||
# let self.data be the data in all environments again
|
||||
self.data = whole_data
|
||||
if n_step:
|
||||
if step_count >= n_step:
|
||||
break
|
||||
@ -320,3 +369,24 @@ class Collector(object):
|
||||
batch_data, indice = self.buffer.sample(batch_size)
|
||||
batch_data = self.process_fn(batch_data, self.buffer, indice)
|
||||
return batch_data
|
||||
|
||||
|
||||
def _batch_set_item(source: Batch, indices: np.ndarray,
                    target: Batch, size: int):
    """Write the entries of ``target`` into ``source`` at ``indices``.

    Note the direction: despite the parameter names, ``source`` is the
    batch that gets modified and ``target`` provides the values.

    :param source: the batch that is updated in place.
    :param indices: positions in ``source`` that receive the values.
    :param target: the batch whose non-reserved keys are copied.
    :param size: length used to allocate a key that ``source`` lacks.

    For any key chain k, there are three cases:

    1. source[k] is non-reserved, but target[k] does not exist or is
       reserved: nothing is written;
    2. source[k] does not exist or is reserved, but target[k] is
       non-reserved: source[k] is allocated first, then written;
    3. both source[k] and target[k] are non-reserved: plain write.
    """
    for key, value in target.items():
        if isinstance(value, Batch) and value.is_empty():
            # target[key] is reserved (case 1): nothing to copy
            continue
        current = source.get(key, Batch())
        if isinstance(current, Batch) and current.is_empty():
            # case 2: allocate storage from the first element of the value
            # use __dict__ to avoid many type checks
            source.__dict__[key] = _create_value(value[0], size)
        source.__dict__[key][indices] = value
|
||||
|
8
tianshou/env/__init__.py
vendored
8
tianshou/env/__init__.py
vendored
@ -1,11 +1,15 @@
|
||||
from tianshou.env.basevecenv import BaseVectorEnv
|
||||
from tianshou.env.vecenv import VectorEnv, SubprocVectorEnv, RayVectorEnv
|
||||
from tianshou.env.vecenv.base import BaseVectorEnv
|
||||
from tianshou.env.vecenv.dummy import VectorEnv
|
||||
from tianshou.env.vecenv.subproc import SubprocVectorEnv
|
||||
from tianshou.env.vecenv.asyncenv import AsyncVectorEnv
|
||||
from tianshou.env.vecenv.rayenv import RayVectorEnv
|
||||
from tianshou.env.maenv import MultiAgentEnv
|
||||
|
||||
__all__ = [
|
||||
'BaseVectorEnv',
|
||||
'VectorEnv',
|
||||
'SubprocVectorEnv',
|
||||
'AsyncVectorEnv',
|
||||
'RayVectorEnv',
|
||||
'MultiAgentEnv',
|
||||
]
|
||||
|
247
tianshou/env/vecenv.py
vendored
247
tianshou/env/vecenv.py
vendored
@ -1,247 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from multiprocessing import Process, Pipe
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
try:
|
||||
import ray
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        # all environments are created in the current process
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        """Return attribute ``key`` of every environment, ``None`` where
        an environment does not have it.
        """
        return [getattr(env, key) if hasattr(env, key) else None
                for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked observations.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the environments selected by ``id`` (all if ``None``).

        ``action[i]`` is applied to environment ``id[i]``; the stacked
        ``(obs, rew, done, info)`` follow the order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # actions are aligned with ``id`` by position, so they must not be
        # indexed by the environment id itself
        result = [self.envs[j].step(action[i]) for i, j in enumerate(id)]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment that has a ``seed`` method.

        A scalar seed is expanded to ``seed + k`` for the k-th environment.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        """Render every environment that has a ``render`` method."""
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        """Close all environments and return their results."""
        return [e.close() for e in self.envs]
|
||||
|
||||
|
||||
def worker(parent, p, env_fn_wrapper):
    """Serve environment commands received on the connection ``p``.

    :param parent: the other end of the pipe; it is closed right away
        because this side does not use it.
    :param p: connection on which ``(cmd, data)`` pairs arrive and on
        which every reply is sent.
    :param env_fn_wrapper: object whose ``data`` attribute is a callable
        creating the environment.

    Raises ``NotImplementedError`` on an unknown command.
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'close':
                # reply first, then shut the connection and stop serving
                p.send(env.close())
                p.close()
                return
            if cmd == 'step':
                reply = env.step(data)
            elif cmd == 'reset':
                reply = env.reset()
            elif cmd == 'render':
                reply = env.render(**data) if hasattr(env, 'render') else None
            elif cmd == 'seed':
                reply = env.seed(data) if hasattr(env, 'seed') else None
            elif cmd == 'getattr':
                reply = getattr(env, data) if hasattr(env, data) else None
            else:
                p.close()
                raise NotImplementedError
            p.send(reply)
    except KeyboardInterrupt:
        p.close()
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.closed = False
        # one pipe per environment: the parent end stays here, the child
        # end is handed to the worker process
        pipes = [Pipe() for _ in range(self.env_num)]
        self.parent_remote, self.child_remote = zip(*pipes)
        procs = []
        for parent, child, env_fn in zip(
                self.parent_remote, self.child_remote, env_fns):
            wrapped = CloudpickleWrapper(env_fn)
            procs.append(Process(
                target=worker, args=(parent, child, wrapped), daemon=True))
        self.processes = procs
        for proc in self.processes:
            proc.start()
        # the child ends are not used by this process
        for child in self.child_remote:
            child.close()

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every worker's environment."""
        for remote in self.parent_remote:
            remote.send(['getattr', key])
        return [remote.recv() for remote in self.parent_remote]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Send ``action[i]`` to environment ``id[i]`` (all environments
        if ``id`` is ``None``) and return the stacked replies.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # dispatch every action before collecting any reply
        for pos, env_id in enumerate(id):
            self.parent_remote[env_id].send(['step', action[pos]])
        replies = [self.parent_remote[env_id].recv() for env_id in id]
        obs, rew, done, info = map(np.stack, zip(*replies))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments and stack their observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for env_id in id:
            self.parent_remote[env_id].send(['reset', None])
        replies = [self.parent_remote[env_id].recv() for env_id in id]
        obs = np.stack(replies)
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments; a scalar seed becomes ``seed + k`` for
        the k-th environment.
        """
        if np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for remote, value in zip(self.parent_remote, seed):
            remote.send(['seed', value])
        return [remote.recv() for remote in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Ask every worker to render and return the replies."""
        for remote in self.parent_remote:
            remote.send(['render', kwargs])
        return [remote.recv() for remote in self.parent_remote]

    def close(self) -> List[Any]:
        """Close every environment once and join the worker processes."""
        if self.closed:
            return []
        for remote in self.parent_remote:
            remote.send(['close', None])
        replies = [remote.recv() for remote in self.parent_remote]
        self.closed = True
        for proc in self.processes:
            proc.join()
        return replies
|
||||
|
||||
|
||||
class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. However, according to our
    test, it is about two times slower than
    :class:`~tianshou.env.SubprocVectorEnv`.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        try:
            if not ray.is_initialized():
                ray.init()
        except NameError:
            # ``ray`` is unbound when the optional import failed
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip3 install ray')
        actors = []
        for make_env in env_fns:
            actor_cls = ray.remote(gym.Wrapper).options(num_cpus=0)
            actors.append(actor_cls.remote(make_env()))
        self.envs = actors

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every remote environment."""
        pending = [e.__getattr__.remote(key) for e in self.envs]
        return ray.get(pending)

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Apply ``action[i]`` to environment ``id[i]`` (all environments
        if ``id`` is ``None``) and return the stacked results.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        pending = [self.envs[env_id].step.remote(action[pos])
                   for pos, env_id in enumerate(id)]
        result = ray.get(pending)
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments and stack their observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        pending = [self.envs[env_id].reset.remote() for env_id in id]
        obs = np.stack(ray.get(pending))
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments; a scalar seed becomes ``seed + k`` for
        the k-th environment. Returns ``[]`` if the first environment has
        no ``seed`` attribute.
        """
        if not hasattr(self.envs[0], 'seed'):
            return []
        if np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        pending = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
        return ray.get(pending)

    def render(self, **kwargs) -> List[Any]:
        """Render every environment; a list of ``None`` is returned if the
        first environment has no ``render`` attribute.
        """
        if not hasattr(self.envs[0], 'render'):
            return [None] * len(self.envs)
        pending = [e.render.remote(**kwargs) for e in self.envs]
        return ray.get(pending)

    def close(self) -> List[Any]:
        """Close every remote environment and return the results."""
        pending = [e.close.remote() for e in self.envs]
        return ray.get(pending)
|
0
tianshou/env/vecenv/__init__.py
vendored
Normal file
0
tianshou/env/vecenv/__init__.py
vendored
Normal file
104
tianshou/env/vecenv/asyncenv.py
vendored
Normal file
104
tianshou/env/vecenv/asyncenv.py
vendored
Normal file
@ -0,0 +1,104 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from multiprocessing import connection
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
|
||||
|
||||
class AsyncVectorEnv(SubprocVectorEnv):
    """Vectorized asynchronous environment wrapper based on subprocess.

    :param wait_num: used in asynchronous simulation if the time cost of
        ``env.step`` varies with time and synchronously waiting for all
        environments to finish a step is time-wasting. In that case, we can
        return when ``wait_num`` environments finish a step and keep on
        simulation in these environments. If ``None``, asynchronous simulation
        is disabled; else, ``1 <= wait_num <= env_num``.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None) -> None:
        super().__init__(env_fns)
        self.wait_num = wait_num or len(env_fns)
        assert 1 <= self.wait_num <= len(env_fns), \
            f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
        # connections with an outstanding step() reply; this list is kept
        # index-aligned with self.waiting_id
        self.waiting_conn = []
        # environments in self.ready_id is actually ready
        # but environments in self.waiting_id are just waiting when checked,
        # and they may be ready now, but this is not known until we check it
        # in the step() function
        self.waiting_id = []
        # all environments are ready in the beginning
        self.ready_id = list(range(self.env_num))

    def _assert_and_transform_id(self,
                                 id: Optional[Union[int, List[int]]] = None
                                 ) -> List[int]:
        """Normalize ``id`` to a sequence of environment ids and check that
        each of them is ready (``None`` selects all environments).
        """
        if id is None:
            id = list(range(self.env_num))
        elif np.isscalar(id):
            id = [id]
        for i in id:
            # this check guards both reset() and step()
            assert i not in self.waiting_id, \
                f'Cannot interact with environment {i} which is stepping now!'
            assert i in self.ready_id, \
                f'Can only interact with ready environments {self.ready_id}.'
        return id

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the given ready environments (all if ``id`` is ``None``)."""
        id = self._assert_and_transform_id(id)
        return super().reset(id)

    def render(self, **kwargs) -> List[Any]:
        """Render all environments; raises ``RuntimeError`` while any
        environment is still stepping.
        """
        if len(self.waiting_id) > 0:
            raise RuntimeError(
                f"Environments {self.waiting_id} are still "
                f"stepping, cannot render them now.")
        return super().render(**kwargs)

    def close(self) -> List[Any]:
        """Fetch every outstanding step reply, then close all environments."""
        if self.closed:
            return []
        # finish remaining steps, and close; a single step(None) only waits
        # for about wait_num replies, so keep fetching until nothing is
        # pending, otherwise a stale step reply would be read as the reply
        # to the close command
        while self.waiting_conn:
            self.step(None)
        return super().close()

    def step(self,
             action: Optional[np.ndarray],
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Provide the given action to the environments. The action sequence
        should correspond to the ``id`` argument, and the ``id`` argument
        should be a subset of the ``env_id`` in the last returned ``info``
        (initially they are env_ids of all the environments). If action is
        ``None``, fetch unfinished step() calls instead.

        Every returned ``info`` gets an ``env_id`` entry telling which
        environment produced it. If no step is pending at all, four empty
        arrays are returned.
        """
        if action is not None:
            id = self._assert_and_transform_id(id)
            assert len(action) == len(id)
            for act, env_id in zip(action, id):
                conn = self.parent_remote[env_id]
                conn.send(['step', act])
                self.waiting_conn.append(conn)
                self.waiting_id.append(env_id)
            self.ready_id = [x for x in self.ready_id if x not in id]
        result = []
        while len(self.waiting_conn) > 0 and len(result) < self.wait_num:
            # block until at least one pending environment has replied
            ready_conns = connection.wait(self.waiting_conn)
            for conn in ready_conns:
                waiting_index = self.waiting_conn.index(conn)
                self.waiting_conn.pop(waiting_index)
                env_id = self.waiting_id.pop(waiting_index)
                obs, rew, done, info = conn.recv()
                info["env_id"] = env_id
                result.append((obs, rew, done, info))
                self.ready_id.append(env_id)
        if not result:
            # nothing was pending, so there is nothing to stack; unpacking
            # an empty zip below would raise ValueError
            return np.array([]), np.array([]), np.array([]), np.array([])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info
|
65
tianshou/env/vecenv/dummy.py
vendored
Normal file
65
tianshou/env/vecenv/dummy.py
vendored
Normal file
@ -0,0 +1,65 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
|
||||
|
||||
class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        # all environments are created in the current process
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        """Return attribute ``key`` of every environment, ``None`` where
        an environment does not have it.
        """
        return [getattr(env, key) if hasattr(env, key) else None
                for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``) and
        return their stacked observations.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the environments selected by ``id`` (all if ``None``).

        ``action[i]`` is applied to environment ``id[i]``; the stacked
        ``(obs, rew, done, info)`` follow the order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # actions are aligned with ``id`` by position, so they must not be
        # indexed by the environment id itself
        result = [self.envs[j].step(action[i]) for i, j in enumerate(id)]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment that has a ``seed`` method.

        A scalar seed is expanded to ``seed + k`` for the k-th environment.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        """Render every environment that has a ``render`` method."""
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        """Close all environments and return their results."""
        return [e.close() for e in self.envs]
|
76
tianshou/env/vecenv/rayenv.py
vendored
Normal file
76
tianshou/env/vecenv/rayenv.py
vendored
Normal file
@ -0,0 +1,76 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
try:
|
||||
import ray
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
|
||||
|
||||
class RayVectorEnv(BaseVectorEnv):
|
||||
"""Vectorized environment wrapper based on
|
||||
`ray <https://github.com/ray-project/ray>`_. This is a choice to run
|
||||
distributed environments in a cluster.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
super().__init__(env_fns)
|
||||
try:
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
except NameError:
|
||||
raise ImportError(
|
||||
'Please install ray to support RayVectorEnv: pip install ray')
|
||||
self.envs = [
|
||||
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
|
||||
for e in env_fns]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return ray.get([e.__getattr__.remote(key) for e in self.envs])
|
||||
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
result = ray.get([self.envs[j].step.remote(action[i])
|
||||
for i, j in enumerate(id)])
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
|
||||
return obs
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if not hasattr(self.envs[0], 'seed'):
|
||||
return []
|
||||
if np.isscalar(seed):
|
||||
seed = [seed + _ for _ in range(self.env_num)]
|
||||
elif seed is None:
|
||||
seed = [seed] * self.env_num
|
||||
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
|
||||
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
if not hasattr(self.envs[0], 'render'):
|
||||
return [None for e in self.envs]
|
||||
return ray.get([e.render.remote(**kwargs) for e in self.envs])
|
||||
|
||||
def close(self) -> List[Any]:
    """Close every remote environment.

    :return: the values returned by each environment's ``close`` call.
    """
    pending = [env.close.remote() for env in self.envs]
    return ray.get(pending)
|
115
tianshou/env/vecenv/subproc.py
vendored
Normal file
115
tianshou/env/vecenv/subproc.py
vendored
Normal file
@ -0,0 +1,115 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from multiprocessing import Process, Pipe
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
def worker(parent, p, env_fn_wrapper):
    """Serve environment commands received over a pipe until told to stop.

    The loop reads ``(cmd, data)`` pairs from ``p`` and sends back the result
    of the matching environment call. Supported commands are ``'step'``,
    ``'reset'``, ``'close'``, ``'render'``, ``'seed'`` and ``'getattr'``.
    For ``'render'``, ``'seed'`` and ``'getattr'``, ``None`` is sent when the
    environment lacks the requested attribute.

    :param parent: the other end of the pipe; closed immediately because
        this function only communicates through ``p``.
    :param p: connection used to receive commands and send results.
    :param env_fn_wrapper: object whose ``data`` attribute is a callable
        that builds the environment.
    :raises NotImplementedError: if an unknown command is received.
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            try:
                cmd, data = p.recv()
            except EOFError:
                # The other end was closed without sending 'close' (for
                # example the parent went away): shut down quietly instead
                # of dying with a traceback and a dangling connection.
                p.close()
                break
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                p.send(getattr(env, data) if hasattr(env, data) else None)
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    Each environment lives in its own daemon process running ``worker`` and
    is driven through a dedicated pipe.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Spawn one worker process per environment factory.

        :param env_fns: callables that each build one ``gym.Env``; each is
            wrapped in ``CloudpickleWrapper`` and handed to its worker.
        """
        super().__init__(env_fns)
        self.closed = False
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=worker, args=(
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        for p in self.processes:
            p.start()
        # The child ends are only used inside the workers; drop this
        # process's handles to them.
        for c in self.child_remote:
            c.close()

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every worker's environment.

        :param key: name of the attribute to look up.
        :return: a list with one value per environment (``None`` where the
            environment has no such attribute).
        :raises AttributeError: if the pipes have not been created yet.
        """
        # __getattr__ only runs when normal lookup fails. Reading
        # ``self.parent_remote`` while it is missing would re-enter this
        # method until RecursionError, so go through __dict__ instead.
        try:
            remotes = self.__dict__['parent_remote']
        except KeyError:
            raise AttributeError(key) from None
        for p in remotes:
            p.send(['getattr', key])
        return [p.recv() for p in remotes]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments once.

        :param action: one action per selected environment; ``action[i]`` is
            sent to the environment named by the ``i``-th entry of ``id``.
        :param id: an index or a list of indices of the environments to
            step; ``None`` selects all ``self.env_num`` environments.
        :return: ``(obs, rew, done, info)``, each built with ``np.stack``
            from the per-environment step results.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # Send every command before reading any reply so the workers can
        # run concurrently.
        for i, j in enumerate(id):
            self.parent_remote[j].send(['step', action[i]])
        result = [self.parent_remote[i].recv() for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments.

        :param id: an index or a list of indices of the environments to
            reset; ``None`` selects all ``self.env_num`` environments.
        :return: the reset observations combined with ``np.stack``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for i in id:
            self.parent_remote[i].send(['reset', None])
        obs = np.stack([self.parent_remote[i].recv() for i in id])
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment.

        :param seed: a scalar seed is expanded to ``seed + i`` for the
            ``i``-th environment; ``None`` is passed to every environment
            unchanged; any other value is used as a per-environment
            sequence of seeds.
        :return: one reply per environment (``None`` where the environment
            has no ``seed``).
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Render every environment.

        :param kwargs: forwarded unchanged to each environment's ``render``.
        :return: one reply per environment (``None`` where the environment
            has no ``render``).
        """
        for p in self.parent_remote:
            p.send(['render', kwargs])
        return [p.recv() for p in self.parent_remote]

    def close(self) -> List[Any]:
        """Close every environment and wait for the workers to exit.

        Calling this again after a successful close returns ``[]``. A worker
        whose pipe is already broken is treated as closed and contributes
        ``None`` to the result, so one dead worker does not prevent the
        others from being shut down and joined.

        :return: one ``close`` reply per environment.
        """
        if self.closed:
            return []
        for p in self.parent_remote:
            try:
                p.send(['close', None])
            except BrokenPipeError:
                # The worker is already gone; nothing left to tell it.
                pass
        result = []
        for p in self.parent_remote:
            try:
                result.append(p.recv())
            except EOFError:
                # The worker exited without replying.
                result.append(None)
        self.closed = True
        for p in self.processes:
            p.join()
        return result
|
Loading…
x
Reference in New Issue
Block a user