Tianshou/tianshou/env/vecenv/asyncenv.py
Alexis DUBURCQ e024afab8c
Asynchronous sampling vector environment (#134)
Fix #103

Co-authored-by: youkaichao <youkaichao@126.com>
Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-07-26 18:01:21 +08:00

105 lines
4.3 KiB
Python

import gym
import numpy as np
from multiprocessing import connection
from typing import List, Tuple, Union, Optional, Callable, Any
from tianshou.env import SubprocVectorEnv
class AsyncVectorEnv(SubprocVectorEnv):
    """Vectorized asynchronous environment wrapper based on subprocess.

    :param env_fns: a list of callables, each of which returns a
        ``gym.Env``; forwarded unchanged to the parent constructor.
    :param wait_num: used in asynchronous simulation if the time cost of
        ``env.step`` varies with time and synchronously waiting for all
        environments to finish a step is time-wasting. In that case, we can
        return when ``wait_num`` environments finish a step and keep on
        simulation in these environments. If ``None``, asynchronous
        simulation is disabled (``wait_num`` falls back to the number of
        environments); else, ``1 <= wait_num <= env_num``.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more
        detailed explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None) -> None:
        super().__init__(env_fns)
        # NOTE: because of ``or``, any falsy value (``None`` and also ``0``)
        # falls back to "wait for every environment".
        self.wait_num = wait_num or len(env_fns)
        assert 1 <= self.wait_num <= len(env_fns), \
            f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
        # connections of the environments with an unfinished step() call;
        # kept index-aligned with ``self.waiting_id``
        self.waiting_conn = []
        # environments in self.ready_id is actually ready
        # but environments in self.waiting_id are just waiting when checked,
        # and they may be ready now, but this is not known until we check it
        # in the step() function
        self.waiting_id = []
        # all environments are ready in the beginning
        self.ready_id = list(range(self.env_num))

    def _assert_and_transform_id(self,
                                 id: Optional[Union[int, List[int]]] = None
                                 ) -> List[int]:
        """Normalize ``id`` to a list and check every entry is ready.

        :param id: ``None`` (meaning all environments), a single scalar
            environment id, or a sequence of environment ids.
        :return: the environment ids as an iterable of ids (a list for the
            ``None`` and scalar cases; a given sequence is returned as-is).
        :raises AssertionError: if any id is currently stepping or is not
            in ``self.ready_id``.
        """
        if id is None:
            id = list(range(self.env_num))
        elif np.isscalar(id):
            id = [id]
        for i in id:
            assert i not in self.waiting_id, \
                f'Cannot reset environment {i} which is stepping now!'
            assert i in self.ready_id, \
                f'Can only reset ready environments {self.ready_id}.'
        return id

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the given (ready) environments via the parent class.

        :param id: environment id(s) to reset; ``None`` means all of them.
        :raises AssertionError: if one of the environments is still stepping
            or otherwise not ready.
        """
        id = self._assert_and_transform_id(id)
        return super().reset(id)

    def render(self, **kwargs) -> List[Any]:
        """Render through the parent class.

        :raises RuntimeError: if any environment still has an unfinished
            step() call.
        """
        if len(self.waiting_id) > 0:
            raise RuntimeError(
                f"Environments {self.waiting_id} are still "
                f"stepping, cannot render them now.")
        return super().render(**kwargs)

    def close(self) -> List[Any]:
        """Collect every unfinished step, then close via the parent class.

        :return: an empty list if already closed, otherwise whatever the
            parent ``close()`` returns.
        """
        if self.closed:
            return []
        # finish remaining steps, and close. A single step(None) returns as
        # soon as ``wait_num`` results are collected, so keep fetching until
        # no step reply is left unread on any connection.
        while self.waiting_conn:
            self.step(None)
        return super().close()

    def step(self,
             action: Optional[np.ndarray],
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Provide the given action to the environments. The action sequence
        should correspond to the ``id`` argument, and the ``id`` argument
        should be a subset of the ``env_id`` in the last returned ``info``
        (initially they are env_ids of all the environments). If action is
        ``None``, fetch unfinished step() calls instead.

        :return: ``(obs, rew, done, info)`` stacked over the environments
            whose step finished during this call; each ``info`` entry gets
            an ``"env_id"`` key identifying its environment. Four empty
            arrays are returned if there was nothing to collect.
        :raises AssertionError: if ``action`` and ``id`` differ in length or
            one of the environments in ``id`` is not ready.
        """
        if action is not None:
            id = self._assert_and_transform_id(id)
            assert len(action) == len(id)
            for act, env_id in zip(action, id):
                self.parent_remote[env_id].send(['step', act])
                self.waiting_conn.append(self.parent_remote[env_id])
                self.waiting_id.append(env_id)
            self.ready_id = [x for x in self.ready_id if x not in id]
        result = []
        while len(self.waiting_conn) > 0 and len(result) < self.wait_num:
            # blocks until at least one pending connection has data
            ready_conns = connection.wait(self.waiting_conn)
            for conn in ready_conns:
                # waiting_conn and waiting_id share indices, so pop both
                # at the same position
                waiting_index = self.waiting_conn.index(conn)
                self.waiting_conn.pop(waiting_index)
                env_id = self.waiting_id.pop(waiting_index)
                obs, rew, done, info = conn.recv()
                info["env_id"] = env_id
                result.append((obs, rew, done, info))
                self.ready_id.append(env_id)
        if not result:
            # Nothing was pending: unpacking ``map(np.stack, zip(*[]))``
            # would raise ValueError, so return empty arrays instead.
            return np.array([]), np.array([]), np.array([]), np.array([])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info