Enable getattr for SubprocVecEnv. (#74)

* Enable getattr for SubprocVecEnv.

* Consistent API between VectorEnv and SubprocVecEnv.

* Avoid code duplication. Add unit tests.

* Add docstring.

* Test more branches.

* Fix UT.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-06-05 11:17:43 +02:00 committed by GitHub
parent 66be5641b6
commit 52be533d06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 9 deletions

View File

@ -1,5 +1,6 @@
import gym
import time
import gym
from gym.spaces.discrete import Discrete
class MyTestEnv(gym.Env):
@ -7,6 +8,7 @@ class MyTestEnv(gym.Env):
self.size = size
self.sleep = sleep
self.dict_state = dict_state
self.action_space = Discrete(1)
self.reset()
def reset(self, state=0):

View File

@ -1,5 +1,6 @@
import time
import numpy as np
from gym.spaces.discrete import Discrete
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
if __name__ == '__main__':
@ -11,14 +12,8 @@ else: # pytest
def test_vecenv(size=10, num=8, sleep=0.001):
verbose = __name__ == '__main__'
env_fns = [
lambda: MyTestEnv(size=size, sleep=sleep),
lambda: MyTestEnv(size=size + 1, sleep=sleep),
lambda: MyTestEnv(size=size + 2, sleep=sleep),
lambda: MyTestEnv(size=size + 3, sleep=sleep),
lambda: MyTestEnv(size=size + 4, sleep=sleep),
lambda: MyTestEnv(size=size + 5, sleep=sleep),
lambda: MyTestEnv(size=size + 6, sleep=sleep),
lambda: MyTestEnv(size=size + 7, sleep=sleep),
lambda i=i: MyTestEnv(size=i, sleep=sleep)
for i in range(size, size + num)
]
venv = [
VectorEnv(env_fns),
@ -54,6 +49,11 @@ def test_vecenv(size=10, num=8, sleep=0.001):
print(f'VectorEnv: {t[0]:.6f}s')
print(f'SubprocVectorEnv: {t[1]:.6f}s')
print(f'RayVectorEnv: {t[2]:.6f}s')
for v in venv:
assert v.size == list(range(size, size + num))
assert v.env_num == num
assert v.action_space == [Discrete(1)] * num
for v in venv:
v.close()

View File

@ -40,11 +40,30 @@ class BaseVectorEnv(ABC, gym.Env):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
    """Record the environment factories and clear the cached results.

    :param env_fns: callables that each build one ``gym.Env``; their
        count becomes ``env_num``.
    """
    self._env_fns = env_fns
    self.env_num = len(env_fns)
    # Nothing has been produced yet, so all four slots start out as None.
    self._obs = self._rew = self._done = self._info = None
def __len__(self) -> int:
"""Return len(self), which is the number of environments.

The value is ``self.env_num``, which ``__init__`` sets to
``len(env_fns)``.
"""
return self.env_num
def __getattribute__(self, key):
    """Resolve ``key``, routing the two space attributes to the envs.

    ``observation_space`` and ``action_space`` are always answered by
    :meth:`__getattr__`; every other name goes through the default
    attribute lookup.
    """
    # NOTE(review): presumably needed because the base class already
    # defines these two names, so the default lookup would never fall
    # back to __getattr__ for them -- confirm against gym.Env.
    if key in ('observation_space', 'action_space'):
        return self.__getattr__(key)
    return super().__getattribute__(key)
@abstractmethod
def __getattr__(self, key):
"""Try to retrieve an attribute from each individual wrapped
environment, if it does not belong to the wrapping vector
environment class.

``__getattribute__`` also calls this hook directly for
``observation_space`` and ``action_space``.

:param key: name of the attribute to look up.
"""
pass
@abstractmethod
def reset(self, id: Optional[Union[int, List[int]]] = None):
"""Reset the state of all the environments and return initial
@ -118,6 +137,10 @@ class VectorEnv(BaseVectorEnv):
super().__init__(env_fns)
self.envs = [_() for _ in env_fns]
def __getattr__(self, key):
    """Collect the attribute ``key`` from every wrapped environment.

    :param key: name of the attribute to look up on each environment.
    :return: a list with one entry per environment, in the order of
        ``self.envs``; ``None`` stands in for an environment that does
        not have the attribute.
    :raises AttributeError: if ``envs`` itself has not been set yet.
    """
    if key == 'envs':
        # Reading self.envs below would re-enter this hook without end
        # when the attribute is missing (e.g. on a half-built instance).
        raise AttributeError(key)
    # A single lookup with a default replaces the hasattr()/getattr()
    # pair, which resolved every attribute twice.
    return [getattr(env, key, None) for env in self.envs]
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
if id is None:
self._obs = np.stack([e.reset() for e in self.envs])
@ -179,6 +202,8 @@ def worker(parent, p, env_fn_wrapper):
p.send(env.render(**data) if hasattr(env, 'render') else None)
elif cmd == 'seed':
p.send(env.seed(data) if hasattr(env, 'seed') else None)
elif cmd == 'getattr':
p.send(getattr(env, data) if hasattr(env, data) else None)
else:
p.close()
raise NotImplementedError
@ -211,6 +236,11 @@ class SubprocVectorEnv(BaseVectorEnv):
for c in self.child_remote:
c.close()
def __getattr__(self, key):
    """Ask every worker for the attribute ``key`` of its environment.

    A ``['getattr', key]`` request is sent to all remotes first and the
    replies are read afterwards, so the workers handle it concurrently
    instead of one after another.

    :param key: name of the attribute to look up on each environment.
    :return: a list with one reply per remote, in the order of
        ``self.parent_remote``.
    :raises AttributeError: if ``parent_remote`` has not been set yet.
    """
    if key == 'parent_remote':
        # Reading self.parent_remote below would re-enter this hook
        # without end when the attribute is missing.
        raise AttributeError(key)
    for remote in self.parent_remote:
        remote.send(['getattr', key])
    return [remote.recv() for remote in self.parent_remote]
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
@ -289,6 +319,9 @@ class RayVectorEnv(BaseVectorEnv):
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
for e in env_fns]
def __getattr__(self, key):
    """Fetch the attribute ``key`` from every remote environment.

    :param key: name of the attribute to look up on each environment.
    :return: the list produced by ``ray.get``, one entry per element of
        ``self.envs``.
    :raises AttributeError: if ``envs`` itself has not been set yet.
    """
    if key == 'envs':
        # Reading self.envs below would re-enter this hook without end
        # when the attribute is missing.
        raise AttributeError(key)
    # The actors wrap gym.Wrapper, whose attribute forwarding lives in
    # ``__getattr__``; it has no method named ``getattr`` to call remotely.
    # NOTE(review): missing attributes are not mapped to None here --
    # confirm whether that should match the other vector environments.
    return ray.get([e.__getattr__.remote(key) for e in self.envs])
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num