# Tianshou/tianshou/env/vecenv/asyncenv.py
# (web-viewer header: 105 lines, 4.3 KiB, Python)

import gym
import numpy as np
from multiprocessing import connection
from typing import List, Tuple, Union, Optional, Callable, Any
from tianshou.env import SubprocVectorEnv
class AsyncVectorEnv(SubprocVectorEnv):
    """Vectorized asynchronous environment wrapper based on subprocess.

    :param env_fns: a list of callables, each returning a ``gym.Env``; passed
        unchanged to the parent class constructor.
    :param wait_num: used in asynchronous simulation if the time cost of
        ``env.step`` varies with time and synchronously waiting for all
        environments to finish a step is time-wasting. In that case, we can
        return when ``wait_num`` environments finish a step and keep on
        simulation in these environments. If ``None``, asynchronous simulation
        is disabled; else, ``1 <= wait_num <= env_num``.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None) -> None:
        super().__init__(env_fns)
        # Compare against None explicitly: ``wait_num or len(env_fns)`` would
        # silently map an out-of-range ``wait_num=0`` to ``len(env_fns)`` and
        # the range check below would never see it.
        self.wait_num = len(env_fns) if wait_num is None else wait_num
        assert 1 <= self.wait_num <= len(env_fns), \
            f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
        # connections on which a 'step' request was sent and whose reply has
        # not been received yet; kept index-aligned with self.waiting_id
        self.waiting_conn = []
        # environments in self.ready_id is actually ready
        # but environments in self.waiting_id are just waiting when checked,
        # and they may be ready now, but this is not known until we check it
        # in the step() function
        self.waiting_id = []
        # all environments are ready in the beginning
        self.ready_id = list(range(self.env_num))

    def _assert_and_transform_id(self,
                                 id: Optional[Union[int, List[int]]] = None
                                 ) -> List[int]:
        """Normalize ``id`` to a list and check every entry is ready.

        :param id: ``None`` (meaning all environments), a scalar index, or a
            sequence of indices.
        :return: the indices as a list (a sequence argument is returned as-is).
        :raises AssertionError: if an index is still waiting for a step result
            or is not in ``self.ready_id``.
        """
        if id is None:
            id = list(range(self.env_num))
        elif np.isscalar(id):
            id = [id]
        for i in id:
            assert i not in self.waiting_id, \
                f'Cannot reset environment {i} which is stepping now!'
            assert i in self.ready_id, \
                f'Can only reset ready environments {self.ready_id}.'
        return id

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the given environments after checking that they are ready.

        :param id: see :meth:`_assert_and_transform_id`.
        :return: whatever the parent class ``reset`` returns for these ids.
        """
        id = self._assert_and_transform_id(id)
        return super().reset(id)

    def render(self, **kwargs) -> List[Any]:
        """Render via the parent class, refusing while any step is pending.

        :raises RuntimeError: if some environments are still stepping.
        """
        if len(self.waiting_id) > 0:
            raise RuntimeError(
                f"Environments {self.waiting_id} are still "
                f"stepping, cannot render them now.")
        return super().render(**kwargs)

    def close(self) -> List[Any]:
        """Fetch every outstanding step result, then close via the parent.

        :return: ``[]`` if already closed, else the parent ``close`` result.
        """
        if self.closed:
            return []
        # finish remaining steps, and close.  A single step(None) returns as
        # soon as ``wait_num`` results have arrived, so keep fetching until no
        # request is outstanding.
        while self.waiting_conn:
            self.step(None)
        return super().close()

    def step(self,
             action: Optional[np.ndarray],
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Provide the given action to the environments. The action sequence
        should correspond to the ``id`` argument, and the ``id`` argument
        should be a subset of the ``env_id`` in the last returned ``info``
        (initially they are env_ids of all the environments). If action is
        ``None``, fetch unfinished step() calls instead.

        :return: ``(obs, rew, done, info)``, each built with ``np.stack`` over
            the environments whose results were received in this call; every
            ``info`` entry gets an ``"env_id"`` key with the index of the
            environment it came from. More than ``wait_num`` results may be
            returned when several connections become ready at once. If no
            step() call is outstanding, four empty arrays are returned.
        :raises AssertionError: if ``action`` and ``id`` differ in length or
            an id refers to an environment that is not ready.
        """
        if action is not None:
            id = self._assert_and_transform_id(id)
            assert len(action) == len(id)
            for act, env_id in zip(action, id):
                conn = self.parent_remote[env_id]
                conn.send(['step', act])
                self.waiting_conn.append(conn)
                self.waiting_id.append(env_id)
            # the environments just dispatched are no longer ready
            dispatched = set(id)
            self.ready_id = [x for x in self.ready_id if x not in dispatched]
        result = []
        while len(self.waiting_conn) > 0 and len(result) < self.wait_num:
            ready_conns = connection.wait(self.waiting_conn)
            for conn in ready_conns:
                # waiting_conn and waiting_id share indices, so one position
                # identifies both the connection and its environment id
                waiting_index = self.waiting_conn.index(conn)
                self.waiting_conn.pop(waiting_index)
                env_id = self.waiting_id.pop(waiting_index)
                obs, rew, done, info = conn.recv()
                info["env_id"] = env_id
                result.append((obs, rew, done, info))
                self.ready_id.append(env_id)
        if not result:
            # Nothing was pending: ``zip(*result)`` would be empty and the
            # 4-way unpack below would raise ValueError.
            return np.array([]), np.array([]), np.array([]), np.array([])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info