Enable getattr for SubprocVecEnv. (#74)
* Enable getattr for SubprovVecEnv. * Consistent API between VectorEnv and SubprocVecEnv. * Avoid code duplication. Add unit tests. * Add docstring. * Test more branches. * Fix UT. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
66be5641b6
commit
52be533d06
@ -1,5 +1,6 @@
|
||||
import gym
|
||||
import time
|
||||
import gym
|
||||
from gym.spaces.discrete import Discrete
|
||||
|
||||
|
||||
class MyTestEnv(gym.Env):
|
||||
@ -7,6 +8,7 @@ class MyTestEnv(gym.Env):
|
||||
self.size = size
|
||||
self.sleep = sleep
|
||||
self.dict_state = dict_state
|
||||
self.action_space = Discrete(1)
|
||||
self.reset()
|
||||
|
||||
def reset(self, state=0):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import time
|
||||
import numpy as np
|
||||
from gym.spaces.discrete import Discrete
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -11,14 +12,8 @@ else: # pytest
|
||||
def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
verbose = __name__ == '__main__'
|
||||
env_fns = [
|
||||
lambda: MyTestEnv(size=size, sleep=sleep),
|
||||
lambda: MyTestEnv(size=size + 1, sleep=sleep),
|
||||
lambda: MyTestEnv(size=size + 2, sleep=sleep),
|
||||
lambda: MyTestEnv(size=size + 3, sleep=sleep),
|
||||
lambda: MyTestEnv(size=size + 4, sleep=sleep),
|
||||
lambda: MyTestEnv(size=size + 5, sleep=sleep),
|
||||
lambda: MyTestEnv(size=size + 6, sleep=sleep),
|
||||
lambda: MyTestEnv(size=size + 7, sleep=sleep),
|
||||
lambda i=i: MyTestEnv(size=i, sleep=sleep)
|
||||
for i in range(size, size + num)
|
||||
]
|
||||
venv = [
|
||||
VectorEnv(env_fns),
|
||||
@ -54,6 +49,11 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
print(f'VectorEnv: {t[0]:.6f}s')
|
||||
print(f'SubprocVectorEnv: {t[1]:.6f}s')
|
||||
print(f'RayVectorEnv: {t[2]:.6f}s')
|
||||
for v in venv:
|
||||
assert v.size == list(range(size, size + num))
|
||||
assert v.env_num == num
|
||||
assert v.action_space == [Discrete(1)] * num
|
||||
|
||||
for v in venv:
|
||||
v.close()
|
||||
|
||||
|
33
tianshou/env/vecenv.py
vendored
33
tianshou/env/vecenv.py
vendored
@ -40,11 +40,30 @@ class BaseVectorEnv(ABC, gym.Env):
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
self._env_fns = env_fns
|
||||
self.env_num = len(env_fns)
|
||||
self._obs = None
|
||||
self._rew = None
|
||||
self._done = None
|
||||
self._info = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
return self.env_num
|
||||
|
||||
def __getattribute__(self, key):
|
||||
"""Switch between the default attribute getter or one
|
||||
looking at wrapped environment level depending on the key."""
|
||||
if key not in ('observation_space', 'action_space'):
|
||||
return super().__getattribute__(key)
|
||||
else:
|
||||
return self.__getattr__(key)
|
||||
|
||||
@abstractmethod
|
||||
def __getattr__(self, key):
|
||||
"""Try to retrieve an attribute from each individual wrapped
|
||||
environment, if it does not belong to the wrapping vector
|
||||
environment class."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None):
|
||||
"""Reset the state of all the environments and return initial
|
||||
@ -118,6 +137,10 @@ class VectorEnv(BaseVectorEnv):
|
||||
super().__init__(env_fns)
|
||||
self.envs = [_() for _ in env_fns]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return [getattr(env, key) if hasattr(env, key) else None
|
||||
for env in self.envs]
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
||||
if id is None:
|
||||
self._obs = np.stack([e.reset() for e in self.envs])
|
||||
@ -179,6 +202,8 @@ def worker(parent, p, env_fn_wrapper):
|
||||
p.send(env.render(**data) if hasattr(env, 'render') else None)
|
||||
elif cmd == 'seed':
|
||||
p.send(env.seed(data) if hasattr(env, 'seed') else None)
|
||||
elif cmd == 'getattr':
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
@ -211,6 +236,11 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
for c in self.child_remote:
|
||||
c.close()
|
||||
|
||||
def __getattr__(self, key):
|
||||
for p in self.parent_remote:
|
||||
p.send(['getattr', key])
|
||||
return [p.recv() for p in self.parent_remote]
|
||||
|
||||
def step(self, action: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert len(action) == self.env_num
|
||||
@ -289,6 +319,9 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
|
||||
for e in env_fns]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return ray.get([e.getattr.remote(key) for e in self.envs])
|
||||
|
||||
def step(self, action: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert len(action) == self.env_num
|
||||
|
Loading…
x
Reference in New Issue
Block a user