Asynchronous sampling vector environment (#134)
Fix #103

Co-authored-by: youkaichao <youkaichao@126.com>
Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
parent 30368c29a6
commit e024afab8c
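Before the diff itself, a rough usage sketch of what this change enables, pieced together from the tests added below. `MyTestEnv` and `MyPolicy` are toy classes from the test suite (their import paths here are assumptions for illustration); only `AsyncVectorEnv`, `Collector`, and `ReplayBuffer` are library API shown in this commit.

    import numpy as np

    from tianshou.data import Collector, ReplayBuffer
    from tianshou.env import AsyncVectorEnv

    # toy env/policy used by the tests in this diff (paths assumed)
    from test.base.env import MyTestEnv
    from test.base.test_collector import MyPolicy

    # four environments whose step() time varies; AsyncVectorEnv.step returns
    # as soon as `wait_num` of them finish, instead of waiting for all four
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
               for i in [2, 3, 4, 5]]
    venv = AsyncVectorEnv(env_fns, wait_num=2)

    # the Collector detects AsyncVectorEnv and keeps per-environment bookkeeping
    collector = Collector(MyPolicy(), venv, ReplayBuffer(size=1000))
    collector.collect(n_episode=10)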
@@ -1,5 +1,6 @@
 import gym
 import time
+import random
 import numpy as np
 from gym.spaces import Discrete, MultiDiscrete, Box
 
@@ -9,9 +10,10 @@ class MyTestEnv(gym.Env):
     """
 
     def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
-                 multidiscrete_action=False):
+                 multidiscrete_action=False, random_sleep=False):
         self.size = size
         self.sleep = sleep
+        self.random_sleep = random_sleep
         self.dict_state = dict_state
         self.ma_rew = ma_rew
         self._md_action = multidiscrete_action
@@ -48,7 +50,9 @@ class MyTestEnv(gym.Env):
         if self.done:
             raise ValueError('step after done !!!')
         if self.sleep > 0:
-            time.sleep(self.sleep)
+            sleep_time = random.random() if self.random_sleep else 1
+            sleep_time *= self.sleep
+            time.sleep(sleep_time)
         if self.index == self.size:
             self.done = True
             return self._get_dict_state(), self._get_reward(), self.done, {}
@@ -2,7 +2,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import BasePolicy
-from tianshou.env import VectorEnv, SubprocVectorEnv
+from tianshou.env import VectorEnv, SubprocVectorEnv, AsyncVectorEnv
 from tianshou.data import Collector, Batch, ReplayBuffer
 
 if __name__ == '__main__':
@@ -103,6 +103,51 @@ def test_collector():
     c2.collect(n_episode=[1, 1, 1, 1], random=True)
 
 
+def test_collector_with_async():
+    env_lens = [2, 3, 4, 5]
+    writer = SummaryWriter('log/async_collector')
+    logger = Logger(writer)
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
+               for i in env_lens]
+
+    venv = AsyncVectorEnv(env_fns)
+    policy = MyPolicy()
+    c1 = Collector(policy, venv,
+                   ReplayBuffer(size=1000, ignore_obs_next=False),
+                   logger.preprocess_fn)
+    c1.collect(n_episode=10)
+    # check if the data in the buffer is chronological
+    # i.e. data in the buffer are full episodes, and each episode is
+    # returned by the same environment
+    env_id = c1.buffer.info['env_id']
+    size = len(c1.buffer)
+    obs = c1.buffer.obs[:size]
+    done = c1.buffer.done[:size]
+    print(env_id[:size])
+    print(obs)
+    obs_ground_truth = []
+    i = 0
+    while i < size:
+        # i is the start of an episode
+        if done[i]:
+            # this episode has one transition
+            assert env_lens[env_id[i]] == 1
+            i += 1
+            continue
+        j = i
+        while True:
+            j += 1
+            # in one episode, the environment id is the same
+            assert env_id[j] == env_id[i]
+            if done[j]:
+                break
+        j = j + 1  # j is the start of the next episode
+        assert j - i == env_lens[env_id[i]]
+        obs_ground_truth += list(range(j - i))
+        i = j
+    assert np.allclose(obs, obs_ground_truth)
+
+
 def test_collector_with_dict_state():
     env = MyTestEnv(size=5, sleep=0, dict_state=True)
     policy = MyPolicy(dict_state=True)
@@ -181,3 +226,4 @@ if __name__ == '__main__':
     test_collector()
     test_collector_with_dict_state()
     test_collector_with_ma()
+    test_collector_with_async()
@@ -1,7 +1,9 @@
 import time
 import numpy as np
 from gym.spaces.discrete import Discrete
-from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.data import Batch
+from tianshou.env import VectorEnv, SubprocVectorEnv, \
+    RayVectorEnv, AsyncVectorEnv
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -9,6 +11,43 @@ else:  # pytest
     from test.base.env import MyTestEnv
 
 
+def test_async_env(num=8, sleep=0.1):
+    # simplify the test case, just keep stepping
+    size = 10000
+    env_fns = [
+        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
+        for i in range(size, size + num)
+    ]
+    v = AsyncVectorEnv(env_fns, wait_num=num // 2)
+    v.seed()
+    v.reset()
+    # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
+    # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
+    # expectation of v is n / (n + 1)
+    # for a synchronous environment, the following actions should take
+    # about 7 * sleep * num / (num + 1) seconds
+    # for AsyncVectorEnv, the analysis is complicated, but the time cost
+    # should be smaller
+    action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
+    current_index_start = 0
+    action = action_list[:num]
+    env_ids = list(range(num))
+    o = []
+    spent_time = time.time()
+    while current_index_start < len(action_list):
+        A, B, C, D = v.step(action=action, id=env_ids)
+        b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
+        env_ids = b.info.env_id
+        o.append(b)
+        current_index_start += len(action)
+        action = action_list[current_index_start: current_index_start + len(A)]
+    spent_time = time.time() - spent_time
+    data = Batch.cat(o)
+    # assure 1/7 improvement
+    assert spent_time < 6.0 * sleep * num / (num + 1)
+    return spent_time, data
+
+
 def test_vecenv(size=10, num=8, sleep=0.001):
     verbose = __name__ == '__main__'
     env_fns = [
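The timing bound in `test_async_env` rests on the fact quoted in its comments: for n i.i.d. uniforms on [0, 1], the expectation of their maximum is n / (n + 1). A quick standalone numerical check of that claim (a sketch, not part of the commit):

    import numpy as np

    # For u1..un i.i.d. Uniform[0, 1] and v = max(u1, ..., un):
    #   P(v <= x) = x**n, so the density is n * x**(n - 1)
    #   and E[v] = integral_0^1 x * n * x**(n - 1) dx = n / (n + 1).
    n, trials = 8, 100_000
    samples = np.random.rand(trials, n).max(axis=1)
    print(samples.mean(), n / (n + 1))  # both close to 0.888...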
@@ -60,3 +99,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):
 
 if __name__ == '__main__':
     test_vecenv()
+    test_async_env()
@@ -5,10 +5,11 @@ import warnings
 import numpy as np
 from typing import Any, Dict, List, Union, Optional, Callable
 
-from tianshou.env import BaseVectorEnv, VectorEnv
+from tianshou.env import BaseVectorEnv, VectorEnv, AsyncVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.exploration import BaseNoise
 from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
+from tianshou.data.batch import _create_value
 
 
 class Collector(object):
@@ -96,6 +97,13 @@ class Collector(object):
             env = VectorEnv([lambda: env])
         self.env = env
         self.env_num = len(env)
+        # environments that are available in step()
+        # this means all environments in synchronous simulation
+        # but only a subset of environments in asynchronous simulation
+        self._ready_env_ids = np.arange(self.env_num)
+        # self.is_async is a flag to indicate whether this collector works
+        # with asynchronous simulation
+        self.is_async = isinstance(env, AsyncVectorEnv)
         # need cache buffers before storing in the main buffer
         self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
         self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
@@ -105,6 +113,9 @@ class Collector(object):
         self.process_fn = policy.process_fn
         self._action_noise = action_noise
         self._rew_metric = reward_metric or Collector._default_rew_metric
+        # avoid creating attribute outside __init__
+        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
+                          obs_next={}, policy={})
         self.reset()
 
     @staticmethod
@@ -139,6 +150,7 @@ class Collector(object):
         """Reset all of the environment(s)' states and reset all of the cache
         buffers (if need).
         """
+        self._ready_env_ids = np.arange(self.env_num)
         obs = self.env.reset()
         if self.preprocess_fn:
             obs = self.preprocess_fn(obs=obs).get('obs', obs)
@@ -159,7 +171,7 @@ class Collector(object):
         self.env.close()
 
     def _reset_state(self, id: Union[int, List[int]]) -> None:
-        """Reset self.data.state[id]."""
+        """Reset the hidden state: self.data.state[id]."""
         state = self.data.state  # it is a reference
         if isinstance(state, torch.Tensor):
             state[id].zero_()
@@ -207,6 +219,7 @@ class Collector(object):
         # episode of each environment
         episode_count = np.zeros(self.env_num)
         reward_total = 0.0
+        whole_data = Batch()
         while True:
             if step_count >= 100000 and episode_count.sum() == 0:
                 warnings.warn(
@@ -214,6 +227,15 @@ class Collector(object):
                     'You should add a time limitation to your environment!',
                     Warning)
 
+            if self.is_async:
+                # self.data holds the data for all environments; in async
+                # simulation only a subset of them is stepped at a time.
+                # Keep the full batch in ``whole_data``, let self.data be
+                # the data of the ready environments only, and write the
+                # results back into ``whole_data`` afterwards.
+                whole_data = self.data
+                self.data = self.data[self._ready_env_ids]
+
             # restore the state and the input data
             last_state = self.data.state
             if last_state.is_empty():
|
|||||||
|
|
||||||
# calculate the next action
|
# calculate the next action
|
||||||
if random:
|
if random:
|
||||||
|
if self.is_async:
|
||||||
|
# TODO self.env.action_space will invoke remote call for
|
||||||
|
# all environments, which may hang in async simulation.
|
||||||
|
# This can be avoided by using a random policy, but not
|
||||||
|
# in the collector level. Leave it as a future work.
|
||||||
|
raise RuntimeError("cannot use random "
|
||||||
|
"sampling in async simulation!")
|
||||||
|
spaces = self.env.action_space
|
||||||
result = Batch(
|
result = Batch(
|
||||||
act=[a.sample() for a in self.env.action_space])
|
act=[spaces[i].sample() for i in self._ready_env_ids])
|
||||||
else:
|
else:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
result = self.policy(self.data, last_state)
|
result = self.policy(self.data, last_state)
|
||||||
@@ -243,8 +273,18 @@ class Collector(object):
                 self.data.act += self._action_noise(self.data.act.shape)
 
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act)
+            if not self.is_async:
+                obs_next, rew, done, info = self.env.step(self.data.act)
+            else:
+                # store computed actions, states, etc
+                _batch_set_item(whole_data, self._ready_env_ids,
+                                self.data, self.env_num)
+                # fetch finished data
+                obs_next, rew, done, info = self.env.step(
+                    action=self.data.act, id=self._ready_env_ids)
+                self._ready_env_ids = np.array([i['env_id'] for i in info])
+                # get the stepped data
+                self.data = whole_data[self._ready_env_ids]
             # move data to self.data
             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
 
@@ -256,9 +296,11 @@ class Collector(object):
             if self.preprocess_fn:
                 result = self.preprocess_fn(**self.data)
                 self.data.update(result)
-            for i in range(self.env_num):
-                self._cached_buf[i].add(**self.data[i])
-                if self.data.done[i]:
+            for j, i in enumerate(self._ready_env_ids):
+                # j is the index in current ready_env_ids
+                # i is the index in all environments
+                self._cached_buf[i].add(**self.data[j])
+                if self.data.done[j]:
                     if n_step or np.isscalar(n_episode) or \
                             episode_count[i] < n_episode[i]:
                         episode_count[i] += 1
@@ -267,17 +309,24 @@ class Collector(object):
                         if self.buffer is not None:
                             self.buffer.update(self._cached_buf[i])
                         self._cached_buf[i].reset()
-                        self._reset_state(i)
+                        self._reset_state(j)
             obs_next = self.data.obs_next
             if sum(self.data.done):
-                env_ind = np.where(self.data.done)[0]
-                obs_reset = self.env.reset(env_ind)
+                env_ind_local = np.where(self.data.done)[0]
+                env_ind_global = self._ready_env_ids[env_ind_local]
+                obs_reset = self.env.reset(env_ind_global)
                 if self.preprocess_fn:
-                    obs_next[env_ind] = self.preprocess_fn(
+                    obs_next[env_ind_local] = self.preprocess_fn(
                         obs=obs_reset).get('obs', obs_reset)
                 else:
-                    obs_next[env_ind] = obs_reset
+                    obs_next[env_ind_local] = obs_reset
             self.data.obs = obs_next
+            if self.is_async:
+                # set data back
+                _batch_set_item(whole_data, self._ready_env_ids,
+                                self.data, self.env_num)
+                # let self.data be the data in all environments again
+                self.data = whole_data
             if n_step:
                 if step_count >= n_step:
                     break
@@ -320,3 +369,24 @@ class Collector(object):
         batch_data, indice = self.buffer.sample(batch_size)
         batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
+
+
+def _batch_set_item(source: Batch, indices: np.ndarray,
+                    target: Batch, size: int):
+    # for any key chain k, there are three cases
+    # 1. source[k] is non-reserved, but target[k] does not exist or is reserved
+    # 2. source[k] does not exist or is reserved, but target[k] is non-reserved
+    # 3. both source[k] and target[k] are non-reserved
+    for k, v in target.items():
+        if not isinstance(v, Batch) or not v.is_empty():
+            # target[k] is non-reserved
+            vs = source.get(k, Batch())
+            if isinstance(vs, Batch) and vs.is_empty():
+                # case 2
+                # use __dict__ to avoid many type checks
+                source.__dict__[k] = _create_value(v[0], size)
+        else:
+            # target[k] is reserved
+            # case 1
+            continue
+        source.__dict__[k][indices] = v
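The helper above scatters the ready-subset batch back into the batch that covers all environments. A toy sketch of the same index bookkeeping using plain numpy arrays (hypothetical names, independent of the tianshou `Batch` class):

    import numpy as np

    env_num = 4
    ready_ids = np.array([1, 3])  # environments that just finished a step

    # "whole" batch: one row per environment
    whole_obs = np.zeros((env_num, 2))
    # "ready" batch: rows only for the ready environments
    ready_obs = np.array([[1.0, 1.0], [3.0, 3.0]])

    # the scatter-write performed by _batch_set_item, in numpy terms
    whole_obs[ready_ids] = ready_obs
    print(whole_obs)  # rows 1 and 3 updated, rows 0 and 2 untouched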
tianshou/env/__init__.py
@@ -1,11 +1,15 @@
-from tianshou.env.basevecenv import BaseVectorEnv
-from tianshou.env.vecenv import VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.env.vecenv.base import BaseVectorEnv
+from tianshou.env.vecenv.dummy import VectorEnv
+from tianshou.env.vecenv.subproc import SubprocVectorEnv
+from tianshou.env.vecenv.asyncenv import AsyncVectorEnv
+from tianshou.env.vecenv.rayenv import RayVectorEnv
 from tianshou.env.maenv import MultiAgentEnv
 
 __all__ = [
     'BaseVectorEnv',
     'VectorEnv',
     'SubprocVectorEnv',
+    'AsyncVectorEnv',
     'RayVectorEnv',
     'MultiAgentEnv',
 ]
tianshou/env/vecenv.py (deleted)
@@ -1,247 +0,0 @@
import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any

try:
    import ray
except ImportError:
    pass

from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper


class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        return [getattr(env, key) if hasattr(env, key) else None
                for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        result = [self.envs[i].step(action[i]) for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        return [e.close() for e in self.envs]


def worker(parent, p, env_fn_wrapper):
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                p.send(getattr(env, data) if hasattr(env, data) else None)
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()


class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.closed = False
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=worker, args=(
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        for p in self.processes:
            p.start()
        for c in self.child_remote:
            c.close()

    def __getattr__(self, key):
        for p in self.parent_remote:
            p.send(['getattr', key])
        return [p.recv() for p in self.parent_remote]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        for i, j in enumerate(id):
            self.parent_remote[j].send(['step', action[i]])
        result = [self.parent_remote[i].recv() for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for i in id:
            self.parent_remote[i].send(['reset', None])
        obs = np.stack([self.parent_remote[i].recv() for i in id])
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        for p in self.parent_remote:
            p.send(['render', kwargs])
        return [p.recv() for p in self.parent_remote]

    def close(self) -> List[Any]:
        if self.closed:
            return []
        for p in self.parent_remote:
            p.send(['close', None])
        result = [p.recv() for p in self.parent_remote]
        self.closed = True
        for p in self.processes:
            p.join()
        return result


class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. However, according to our
    test, it is about two times slower than
    :class:`~tianshou.env.SubprocVectorEnv`.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        try:
            if not ray.is_initialized():
                ray.init()
        except NameError:
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip3 install ray')
        self.envs = [
            ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
            for e in env_fns]

    def __getattr__(self, key):
        return ray.get([e.__getattr__.remote(key) for e in self.envs])

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        result = ray.get([self.envs[j].step.remote(action[i])
                          for i, j in enumerate(id)])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        if not hasattr(self.envs[0], 'seed'):
            return []
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])

    def render(self, **kwargs) -> List[Any]:
        if not hasattr(self.envs[0], 'render'):
            return [None for e in self.envs]
        return ray.get([e.render.remote(**kwargs) for e in self.envs])

    def close(self) -> List[Any]:
        return ray.get([e.close.remote() for e in self.envs])
tianshou/env/vecenv/__init__.py (new, empty file)

tianshou/env/vecenv/asyncenv.py (new file)
@@ -0,0 +1,104 @@
import gym
import numpy as np
from multiprocessing import connection
from typing import List, Tuple, Union, Optional, Callable, Any

from tianshou.env import SubprocVectorEnv


class AsyncVectorEnv(SubprocVectorEnv):
    """Vectorized asynchronous environment wrapper based on subprocess.

    :param wait_num: used in asynchronous simulation if the time cost of
        ``env.step`` varies with time and synchronously waiting for all
        environments to finish a step is time-wasting. In that case, we can
        return when ``wait_num`` environments finish a step and keep on
        simulation in these environments. If ``None``, asynchronous simulation
        is disabled; else, ``1 <= wait_num <= env_num``.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None) -> None:
        super().__init__(env_fns)
        self.wait_num = wait_num or len(env_fns)
        assert 1 <= self.wait_num <= len(env_fns), \
            f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
        self.waiting_conn = []
        # environments in self.ready_id are actually ready,
        # but environments in self.waiting_id are just waiting when checked,
        # and they may be ready now, but this is not known until we check it
        # in the step() function
        self.waiting_id = []
        # all environments are ready in the beginning
        self.ready_id = list(range(self.env_num))

    def _assert_and_transform_id(self,
                                 id: Optional[Union[int, List[int]]] = None
                                 ) -> List[int]:
        if id is None:
            id = list(range(self.env_num))
        elif np.isscalar(id):
            id = [id]
        for i in id:
            assert i not in self.waiting_id, \
                f'Cannot reset environment {i} which is stepping now!'
            assert i in self.ready_id, \
                f'Can only reset ready environments {self.ready_id}.'
        return id

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        id = self._assert_and_transform_id(id)
        return super().reset(id)

    def render(self, **kwargs) -> List[Any]:
        if len(self.waiting_id) > 0:
            raise RuntimeError(
                f"Environments {self.waiting_id} are still "
                f"stepping, cannot render them now.")
        return super().render(**kwargs)

    def close(self) -> List[Any]:
        if self.closed:
            return []
        # finish remaining steps, and close
        self.step(None)
        return super().close()

    def step(self,
             action: Optional[np.ndarray],
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Provide the given action to the environments. The action sequence
        should correspond to the ``id`` argument, and the ``id`` argument
        should be a subset of the ``env_id`` in the last returned ``info``
        (initially they are env_ids of all the environments). If action is
        ``None``, fetch unfinished step() calls instead.
        """
        if action is not None:
            id = self._assert_and_transform_id(id)
            assert len(action) == len(id)
            for i, (act, env_id) in enumerate(zip(action, id)):
                self.parent_remote[env_id].send(['step', act])
                self.waiting_conn.append(self.parent_remote[env_id])
                self.waiting_id.append(env_id)
            self.ready_id = [x for x in self.ready_id if x not in id]
        result = []
        while len(self.waiting_conn) > 0 and len(result) < self.wait_num:
            ready_conns = connection.wait(self.waiting_conn)
            for conn in ready_conns:
                waiting_index = self.waiting_conn.index(conn)
                self.waiting_conn.pop(waiting_index)
                env_id = self.waiting_id.pop(waiting_index)
                ans = conn.recv()
                obs, rew, done, info = ans
                info["env_id"] = env_id
                result.append((obs, rew, done, info))
                self.ready_id.append(env_id)
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info
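The readiness check in `AsyncVectorEnv.step` relies on `multiprocessing.connection.wait`, which blocks until at least one of the given connections has data to read. A minimal standalone sketch of that primitive (toy pipes and worker, not tianshou code):

    import time
    from multiprocessing import Pipe, Process
    from multiprocessing import connection


    def slow_echo(conn, delay):
        # child process: wait, then answer once
        time.sleep(delay)
        conn.send(delay)


    if __name__ == '__main__':
        parents = []
        for delay in (0.1, 0.5, 1.0):
            parent, child = Pipe()
            Process(target=slow_echo, args=(child, delay), daemon=True).start()
            parents.append(parent)
        # returns as soon as at least one connection is readable, so the
        # fast worker is handled without waiting for the slow ones
        ready = connection.wait(parents)
        print([c.recv() for c in ready])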
tianshou/env/vecenv/dummy.py (new file)
@@ -0,0 +1,65 @@
import gym
import numpy as np
from typing import List, Tuple, Union, Optional, Callable, Any

from tianshou.env import BaseVectorEnv


class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        return [getattr(env, key) if hasattr(env, key) else None
                for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        result = [self.envs[i].step(action[i]) for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        return [e.close() for e in self.envs]
tianshou/env/vecenv/rayenv.py (new file)
@@ -0,0 +1,76 @@
import gym
import numpy as np
from typing import List, Tuple, Union, Optional, Callable, Any

try:
    import ray
except ImportError:
    pass

from tianshou.env import BaseVectorEnv


class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. This is a choice to run
    distributed environments in a cluster.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        try:
            if not ray.is_initialized():
                ray.init()
        except NameError:
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip install ray')
        self.envs = [
            ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
            for e in env_fns]

    def __getattr__(self, key):
        return ray.get([e.__getattr__.remote(key) for e in self.envs])

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        result = ray.get([self.envs[j].step.remote(action[i])
                          for i, j in enumerate(id)])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        if not hasattr(self.envs[0], 'seed'):
            return []
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])

    def render(self, **kwargs) -> List[Any]:
        if not hasattr(self.envs[0], 'render'):
            return [None for e in self.envs]
        return ray.get([e.render.remote(**kwargs) for e in self.envs])

    def close(self) -> List[Any]:
        return ray.get([e.close.remote() for e in self.envs])
tianshou/env/vecenv/subproc.py (new file)
@@ -0,0 +1,115 @@
import gym
import numpy as np
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable, Any

from tianshou.env import BaseVectorEnv
from tianshou.env.utils import CloudpickleWrapper


def worker(parent, p, env_fn_wrapper):
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                p.send(getattr(env, data) if hasattr(env, data) else None)
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()


class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.closed = False
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=worker, args=(
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        for p in self.processes:
            p.start()
        for c in self.child_remote:
            c.close()

    def __getattr__(self, key):
        for p in self.parent_remote:
            p.send(['getattr', key])
        return [p.recv() for p in self.parent_remote]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        for i, j in enumerate(id):
            self.parent_remote[j].send(['step', action[i]])
        result = [self.parent_remote[i].recv() for i in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for i in id:
            self.parent_remote[i].send(['reset', None])
        obs = np.stack([self.parent_remote[i].recv() for i in id])
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        for p in self.parent_remote:
            p.send(['render', kwargs])
        return [p.recv() for p in self.parent_remote]

    def close(self) -> List[Any]:
        if self.closed:
            return []
        for p in self.parent_remote:
            p.send(['close', None])
        result = [p.recv() for p in self.parent_remote]
        self.closed = True
        for p in self.processes:
            p.join()
        return result