type signature correction (#83)

This commit is contained in:
youkaichao 2020-06-20 09:57:16 +08:00 committed by GitHub
parent 81e4a16ef2
commit 268f9d0533
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -2,7 +2,7 @@ import gym
import numpy as np import numpy as np
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable from typing import List, Tuple, Union, Optional, Callable, Any
try: try:
import ray import ray
@@ -141,7 +141,7 @@ class VectorEnv(BaseVectorEnv):
return [getattr(env, key) if hasattr(env, key) else None return [getattr(env, key) if hasattr(env, key) else None
for env in self.envs] for env in self.envs]
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None: if id is None:
self._obs = np.stack([e.reset() for e in self.envs]) self._obs = np.stack([e.reset() for e in self.envs])
else: else:
@@ -173,14 +173,14 @@ class VectorEnv(BaseVectorEnv):
result.append(e.seed(s)) result.append(e.seed(s))
return result return result
def render(self, **kwargs) -> None: def render(self, **kwargs) -> List[Any]:
result = [] result = []
for e in self.envs: for e in self.envs:
if hasattr(e, 'render'): if hasattr(e, 'render'):
result.append(e.render(**kwargs)) result.append(e.render(**kwargs))
return result return result
def close(self) -> None: def close(self) -> List[Any]:
return [e.close() for e in self.envs] return [e.close() for e in self.envs]
@@ -254,7 +254,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info) self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info return self._obs, self._rew, self._done, self._info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None: if id is None:
for p in self.parent_remote: for p in self.parent_remote:
p.send(['reset', None]) p.send(['reset', None])
@@ -278,14 +278,14 @@ class SubprocVectorEnv(BaseVectorEnv):
p.send(['seed', s]) p.send(['seed', s])
return [p.recv() for p in self.parent_remote] return [p.recv() for p in self.parent_remote]
def render(self, **kwargs) -> None: def render(self, **kwargs) -> List[Any]:
for p in self.parent_remote: for p in self.parent_remote:
p.send(['render', kwargs]) p.send(['render', kwargs])
return [p.recv() for p in self.parent_remote] return [p.recv() for p in self.parent_remote]
def close(self) -> None: def close(self) -> List[Any]:
if self.closed: if self.closed:
return return []
for p in self.parent_remote: for p in self.parent_remote:
p.send(['close', None]) p.send(['close', None])
result = [p.recv() for p in self.parent_remote] result = [p.recv() for p in self.parent_remote]
@@ -333,7 +333,7 @@ class RayVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info) self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info return self._obs, self._rew, self._done, self._info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None: if id is None:
result_obj = [e.reset.remote() for e in self.envs] result_obj = [e.reset.remote() for e in self.envs]
self._obs = np.stack(ray.get(result_obj)) self._obs = np.stack(ray.get(result_obj))
@@ -349,17 +349,17 @@ class RayVectorEnv(BaseVectorEnv):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if not hasattr(self.envs[0], 'seed'): if not hasattr(self.envs[0], 'seed'):
return return []
if np.isscalar(seed): if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)] seed = [seed + _ for _ in range(self.env_num)]
elif seed is None: elif seed is None:
seed = [seed] * self.env_num seed = [seed] * self.env_num
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)]) return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
def render(self, **kwargs) -> None: def render(self, **kwargs) -> List[Any]:
if not hasattr(self.envs[0], 'render'): if not hasattr(self.envs[0], 'render'):
return return [None for e in self.envs]
return ray.get([e.render.remote(**kwargs) for e in self.envs]) return ray.get([e.render.remote(**kwargs) for e in self.envs])
def close(self) -> None: def close(self) -> List[Any]:
return ray.get([e.close.remote() for e in self.envs]) return ray.get([e.close.remote() for e in self.envs])