type signature correction (#83)
This commit is contained in:
parent
81e4a16ef2
commit
268f9d0533
26
tianshou/env/vecenv.py
vendored
26
tianshou/env/vecenv.py
vendored
@ -2,7 +2,7 @@ import gym
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
from multiprocessing import Process, Pipe
|
||||
from typing import List, Tuple, Union, Optional, Callable
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
try:
|
||||
import ray
|
||||
@ -141,7 +141,7 @@ class VectorEnv(BaseVectorEnv):
|
||||
return [getattr(env, key) if hasattr(env, key) else None
|
||||
for env in self.envs]
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
self._obs = np.stack([e.reset() for e in self.envs])
|
||||
else:
|
||||
@ -173,14 +173,14 @@ class VectorEnv(BaseVectorEnv):
|
||||
result.append(e.seed(s))
|
||||
return result
|
||||
|
||||
def render(self, **kwargs) -> None:
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
result = []
|
||||
for e in self.envs:
|
||||
if hasattr(e, 'render'):
|
||||
result.append(e.render(**kwargs))
|
||||
return result
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self) -> List[Any]:
|
||||
return [e.close() for e in self.envs]
|
||||
|
||||
|
||||
@ -254,7 +254,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
self._info = np.stack(self._info)
|
||||
return self._obs, self._rew, self._done, self._info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
for p in self.parent_remote:
|
||||
p.send(['reset', None])
|
||||
@ -278,14 +278,14 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
p.send(['seed', s])
|
||||
return [p.recv() for p in self.parent_remote]
|
||||
|
||||
def render(self, **kwargs) -> None:
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
for p in self.parent_remote:
|
||||
p.send(['render', kwargs])
|
||||
return [p.recv() for p in self.parent_remote]
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self) -> List[Any]:
|
||||
if self.closed:
|
||||
return
|
||||
return []
|
||||
for p in self.parent_remote:
|
||||
p.send(['close', None])
|
||||
result = [p.recv() for p in self.parent_remote]
|
||||
@ -333,7 +333,7 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
self._info = np.stack(self._info)
|
||||
return self._obs, self._rew, self._done, self._info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
result_obj = [e.reset.remote() for e in self.envs]
|
||||
self._obs = np.stack(ray.get(result_obj))
|
||||
@ -349,17 +349,17 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if not hasattr(self.envs[0], 'seed'):
|
||||
return
|
||||
return []
|
||||
if np.isscalar(seed):
|
||||
seed = [seed + _ for _ in range(self.env_num)]
|
||||
elif seed is None:
|
||||
seed = [seed] * self.env_num
|
||||
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
|
||||
|
||||
def render(self, **kwargs) -> None:
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
if not hasattr(self.envs[0], 'render'):
|
||||
return
|
||||
return [None for e in self.envs]
|
||||
return ray.get([e.render.remote(**kwargs) for e in self.envs])
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self) -> List[Any]:
|
||||
return ray.get([e.close.remote() for e in self.envs])
|
||||
|
Loading…
x
Reference in New Issue
Block a user