type signature correction (#83)
This commit is contained in:
parent
81e4a16ef2
commit
268f9d0533
26
tianshou/env/vecenv.py
vendored
26
tianshou/env/vecenv.py
vendored
@ -2,7 +2,7 @@ import gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
from typing import List, Tuple, Union, Optional, Callable
|
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
@ -141,7 +141,7 @@ class VectorEnv(BaseVectorEnv):
|
|||||||
return [getattr(env, key) if hasattr(env, key) else None
|
return [getattr(env, key) if hasattr(env, key) else None
|
||||||
for env in self.envs]
|
for env in self.envs]
|
||||||
|
|
||||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||||
if id is None:
|
if id is None:
|
||||||
self._obs = np.stack([e.reset() for e in self.envs])
|
self._obs = np.stack([e.reset() for e in self.envs])
|
||||||
else:
|
else:
|
||||||
@ -173,14 +173,14 @@ class VectorEnv(BaseVectorEnv):
|
|||||||
result.append(e.seed(s))
|
result.append(e.seed(s))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def render(self, **kwargs) -> None:
|
def render(self, **kwargs) -> List[Any]:
|
||||||
result = []
|
result = []
|
||||||
for e in self.envs:
|
for e in self.envs:
|
||||||
if hasattr(e, 'render'):
|
if hasattr(e, 'render'):
|
||||||
result.append(e.render(**kwargs))
|
result.append(e.render(**kwargs))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> List[Any]:
|
||||||
return [e.close() for e in self.envs]
|
return [e.close() for e in self.envs]
|
||||||
|
|
||||||
|
|
||||||
@ -254,7 +254,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
|||||||
self._info = np.stack(self._info)
|
self._info = np.stack(self._info)
|
||||||
return self._obs, self._rew, self._done, self._info
|
return self._obs, self._rew, self._done, self._info
|
||||||
|
|
||||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||||
if id is None:
|
if id is None:
|
||||||
for p in self.parent_remote:
|
for p in self.parent_remote:
|
||||||
p.send(['reset', None])
|
p.send(['reset', None])
|
||||||
@ -278,14 +278,14 @@ class SubprocVectorEnv(BaseVectorEnv):
|
|||||||
p.send(['seed', s])
|
p.send(['seed', s])
|
||||||
return [p.recv() for p in self.parent_remote]
|
return [p.recv() for p in self.parent_remote]
|
||||||
|
|
||||||
def render(self, **kwargs) -> None:
|
def render(self, **kwargs) -> List[Any]:
|
||||||
for p in self.parent_remote:
|
for p in self.parent_remote:
|
||||||
p.send(['render', kwargs])
|
p.send(['render', kwargs])
|
||||||
return [p.recv() for p in self.parent_remote]
|
return [p.recv() for p in self.parent_remote]
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> List[Any]:
|
||||||
if self.closed:
|
if self.closed:
|
||||||
return
|
return []
|
||||||
for p in self.parent_remote:
|
for p in self.parent_remote:
|
||||||
p.send(['close', None])
|
p.send(['close', None])
|
||||||
result = [p.recv() for p in self.parent_remote]
|
result = [p.recv() for p in self.parent_remote]
|
||||||
@ -333,7 +333,7 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
self._info = np.stack(self._info)
|
self._info = np.stack(self._info)
|
||||||
return self._obs, self._rew, self._done, self._info
|
return self._obs, self._rew, self._done, self._info
|
||||||
|
|
||||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||||
if id is None:
|
if id is None:
|
||||||
result_obj = [e.reset.remote() for e in self.envs]
|
result_obj = [e.reset.remote() for e in self.envs]
|
||||||
self._obs = np.stack(ray.get(result_obj))
|
self._obs = np.stack(ray.get(result_obj))
|
||||||
@ -349,17 +349,17 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
|
|
||||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||||
if not hasattr(self.envs[0], 'seed'):
|
if not hasattr(self.envs[0], 'seed'):
|
||||||
return
|
return []
|
||||||
if np.isscalar(seed):
|
if np.isscalar(seed):
|
||||||
seed = [seed + _ for _ in range(self.env_num)]
|
seed = [seed + _ for _ in range(self.env_num)]
|
||||||
elif seed is None:
|
elif seed is None:
|
||||||
seed = [seed] * self.env_num
|
seed = [seed] * self.env_num
|
||||||
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
|
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
|
||||||
|
|
||||||
def render(self, **kwargs) -> None:
|
def render(self, **kwargs) -> List[Any]:
|
||||||
if not hasattr(self.envs[0], 'render'):
|
if not hasattr(self.envs[0], 'render'):
|
||||||
return
|
return [None for e in self.envs]
|
||||||
return ray.get([e.render.remote(**kwargs) for e in self.envs])
|
return ray.get([e.render.remote(**kwargs) for e in self.envs])
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> List[Any]:
|
||||||
return ray.get([e.close.remote() for e in self.envs])
|
return ray.get([e.close.remote() for e in self.envs])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user