From 268f9d0533a362b7f2d8d9888802d0c5a4f01209 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 20 Jun 2020 09:57:16 +0800 Subject: [PATCH] type signature correction (#83) --- tianshou/env/vecenv.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py index 2ca7431..26e9835 100644 --- a/tianshou/env/vecenv.py +++ b/tianshou/env/vecenv.py @@ -2,7 +2,7 @@ import gym import numpy as np from abc import ABC, abstractmethod from multiprocessing import Process, Pipe -from typing import List, Tuple, Union, Optional, Callable +from typing import List, Tuple, Union, Optional, Callable, Any try: import ray @@ -141,7 +141,7 @@ class VectorEnv(BaseVectorEnv): return [getattr(env, key) if hasattr(env, key) else None for env in self.envs] - def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: + def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: if id is None: self._obs = np.stack([e.reset() for e in self.envs]) else: @@ -173,14 +173,14 @@ class VectorEnv(BaseVectorEnv): result.append(e.seed(s)) return result - def render(self, **kwargs) -> None: + def render(self, **kwargs) -> List[Any]: result = [] for e in self.envs: if hasattr(e, 'render'): result.append(e.render(**kwargs)) return result - def close(self) -> None: + def close(self) -> List[Any]: return [e.close() for e in self.envs] @@ -254,7 +254,7 @@ class SubprocVectorEnv(BaseVectorEnv): self._info = np.stack(self._info) return self._obs, self._rew, self._done, self._info - def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: + def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: if id is None: for p in self.parent_remote: p.send(['reset', None]) @@ -278,14 +278,14 @@ class SubprocVectorEnv(BaseVectorEnv): p.send(['seed', s]) return [p.recv() for p in self.parent_remote] - def render(self, **kwargs) -> None: + def render(self, **kwargs) -> List[Any]: for p in self.parent_remote: p.send(['render', kwargs]) return [p.recv() for p in self.parent_remote] - def close(self) -> None: + def close(self) -> List[Any]: if self.closed: - return + return [] for p in self.parent_remote: p.send(['close', None]) result = [p.recv() for p in self.parent_remote] @@ -333,7 +333,7 @@ class RayVectorEnv(BaseVectorEnv): self._info = np.stack(self._info) return self._obs, self._rew, self._done, self._info - def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: + def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: if id is None: result_obj = [e.reset.remote() for e in self.envs] self._obs = np.stack(ray.get(result_obj)) @@ -349,17 +349,17 @@ class RayVectorEnv(BaseVectorEnv): def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: if not hasattr(self.envs[0], 'seed'): - return + return [] if np.isscalar(seed): seed = [seed + _ for _ in range(self.env_num)] elif seed is None: seed = [seed] * self.env_num return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)]) - def render(self, **kwargs) -> None: + def render(self, **kwargs) -> List[Any]: if not hasattr(self.envs[0], 'render'): - return + return [None for e in self.envs] return ray.get([e.render.remote(**kwargs) for e in self.envs]) - def close(self) -> None: + def close(self) -> List[Any]: return ray.get([e.close.remote() for e in self.envs])