Alexis DUBURCQ e024afab8c
Asynchronous sampling vector environment (#134)
Fix #103

Co-authored-by: youkaichao <youkaichao@126.com>
Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-07-26 18:01:21 +08:00

77 lines
2.5 KiB
Python

import gym
import numpy as np
from typing import List, Tuple, Union, Optional, Callable, Any
try:
import ray
except ImportError:
pass
from tianshou.env import BaseVectorEnv
class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. This is a choice to run
    distributed environments in a cluster.

    Each environment produced by ``env_fns`` is wrapped in a
    :class:`gym.Wrapper` that runs as a ray actor. Every method below sends
    the call to the selected actors and blocks on :func:`ray.get` until all
    results are available.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Start ray if needed and create one remote actor per env factory.

        :param env_fns: callables, each returning a ``gym.Env``. Each env is
            built locally by calling the factory, then handed to its actor.
        :raises ImportError: if the optional ``ray`` import at the top of the
            module failed.
        """
        super().__init__(env_fns)
        # ``ray`` is bound at module level only if the optional import
        # succeeded. Check for it explicitly instead of catching NameError
        # around ``ray.init()``: that broad catch would also misreport an
        # unrelated NameError raised inside ray as a missing installation.
        if 'ray' not in globals():
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip install ray')
        if not ray.is_initialized():
            ray.init()
        # Build the remote actor class once, not once per environment.
        # num_cpus=0 means the actors request no CPU from ray's scheduler
        # (presumably so the env count is not capped by available CPUs).
        remote_wrapper = ray.remote(gym.Wrapper)
        self.envs = [
            remote_wrapper.options(num_cpus=0).remote(fn())
            for fn in env_fns]

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every remote environment.

        Python calls this only for names that normal lookup on the wrapper
        does not find. Returns a list with one value per env, in the order
        of ``self.envs``.
        """
        # ``self.envs`` may not exist yet: before ``__init__`` assigns it, or
        # on copy/unpickle paths that skip ``__init__``. Looking it up here
        # would re-enter ``__getattr__`` and recurse forever, so report a
        # plain missing attribute instead.
        if key == 'envs':
            raise AttributeError(key)
        return ray.get([e.__getattr__.remote(key) for e in self.envs])

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments.

        :param action: one action per selected env; ``action[i]`` is sent to
            env ``id[i]``.
        :param id: an env index, a list of indices, or ``None`` for all envs.
        :return: ``(obs, rew, done, info)``, each stacked with ``np.stack``
            across the selected envs, in the order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        result = ray.get([self.envs[j].step.remote(action[i])
                          for i, j in enumerate(id)])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments.

        :param id: an env index, a list of indices, or ``None`` for all envs.
        :return: the initial observations, stacked with ``np.stack`` in the
            order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
        return obs

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment.

        :param seed: a scalar (env ``i`` gets ``seed + i``), a list of
            per-env seeds (paired with envs via ``zip``), or ``None`` (each
            env is seeded with ``None``).
        :return: the per-env results of ``seed``, or ``[]`` if the remote
            actor exposes no ``seed`` method.
        """
        if not hasattr(self.envs[0], 'seed'):
            return []
        if np.isscalar(seed):
            seed = [seed + i for i in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])

    def render(self, **kwargs) -> List[Any]:
        """Render every environment, forwarding ``kwargs`` to each.

        :return: one result per env, or a list of ``None`` if the remote
            actor exposes no ``render`` method.
        """
        if not hasattr(self.envs[0], 'render'):
            return [None] * len(self.envs)
        return ray.get([e.render.remote(**kwargs) for e in self.envs])

    def close(self) -> List[Any]:
        """Close every environment and return the per-env results."""
        return ray.get([e.close.remote() for e in self.envs])