type signature correction (#83)

This commit is contained in:
youkaichao 2020-06-20 09:57:16 +08:00 committed by GitHub
parent 81e4a16ef2
commit 268f9d0533
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@ import gym
import numpy as np
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable
from typing import List, Tuple, Union, Optional, Callable, Any
try:
import ray
@ -141,7 +141,7 @@ class VectorEnv(BaseVectorEnv):
return [getattr(env, key) if hasattr(env, key) else None
for env in self.envs]
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
self._obs = np.stack([e.reset() for e in self.envs])
else:
@ -173,14 +173,14 @@ class VectorEnv(BaseVectorEnv):
result.append(e.seed(s))
return result
def render(self, **kwargs) -> None:
def render(self, **kwargs) -> List[Any]:
result = []
for e in self.envs:
if hasattr(e, 'render'):
result.append(e.render(**kwargs))
return result
def close(self) -> None:
def close(self) -> List[Any]:
return [e.close() for e in self.envs]
@ -254,7 +254,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
for p in self.parent_remote:
p.send(['reset', None])
@ -278,14 +278,14 @@ class SubprocVectorEnv(BaseVectorEnv):
p.send(['seed', s])
return [p.recv() for p in self.parent_remote]
def render(self, **kwargs) -> None:
def render(self, **kwargs) -> List[Any]:
for p in self.parent_remote:
p.send(['render', kwargs])
return [p.recv() for p in self.parent_remote]
def close(self) -> None:
def close(self) -> List[Any]:
if self.closed:
return
return []
for p in self.parent_remote:
p.send(['close', None])
result = [p.recv() for p in self.parent_remote]
@ -333,7 +333,7 @@ class RayVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
result_obj = [e.reset.remote() for e in self.envs]
self._obs = np.stack(ray.get(result_obj))
@ -349,17 +349,17 @@ class RayVectorEnv(BaseVectorEnv):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if not hasattr(self.envs[0], 'seed'):
return
return []
if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)]
elif seed is None:
seed = [seed] * self.env_num
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
def render(self, **kwargs) -> None:
def render(self, **kwargs) -> List[Any]:
if not hasattr(self.envs[0], 'render'):
return
return [None for e in self.envs]
return ray.get([e.render.remote(**kwargs) for e in self.envs])
def close(self) -> None:
def close(self) -> List[Any]:
return ray.get([e.close.remote() for e in self.envs])