Enable getattr for SubprocVecEnv. (#74)
* Enable getattr for SubprocVecEnv. * Consistent API between VectorEnv and SubprocVecEnv. * Avoid code duplication. Add unit tests. * Add docstring. * Test more branches. * Fix UT. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
66be5641b6
commit
52be533d06
@ -1,5 +1,6 @@
|
|||||||
import gym
|
|
||||||
import time
|
import time
|
||||||
|
import gym
|
||||||
|
from gym.spaces.discrete import Discrete
|
||||||
|
|
||||||
|
|
||||||
class MyTestEnv(gym.Env):
|
class MyTestEnv(gym.Env):
|
||||||
@ -7,6 +8,7 @@ class MyTestEnv(gym.Env):
|
|||||||
self.size = size
|
self.size = size
|
||||||
self.sleep = sleep
|
self.sleep = sleep
|
||||||
self.dict_state = dict_state
|
self.dict_state = dict_state
|
||||||
|
self.action_space = Discrete(1)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self, state=0):
|
def reset(self, state=0):
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from gym.spaces.discrete import Discrete
|
||||||
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
|
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -11,14 +12,8 @@ else: # pytest
|
|||||||
def test_vecenv(size=10, num=8, sleep=0.001):
|
def test_vecenv(size=10, num=8, sleep=0.001):
|
||||||
verbose = __name__ == '__main__'
|
verbose = __name__ == '__main__'
|
||||||
env_fns = [
|
env_fns = [
|
||||||
lambda: MyTestEnv(size=size, sleep=sleep),
|
lambda i=i: MyTestEnv(size=i, sleep=sleep)
|
||||||
lambda: MyTestEnv(size=size + 1, sleep=sleep),
|
for i in range(size, size + num)
|
||||||
lambda: MyTestEnv(size=size + 2, sleep=sleep),
|
|
||||||
lambda: MyTestEnv(size=size + 3, sleep=sleep),
|
|
||||||
lambda: MyTestEnv(size=size + 4, sleep=sleep),
|
|
||||||
lambda: MyTestEnv(size=size + 5, sleep=sleep),
|
|
||||||
lambda: MyTestEnv(size=size + 6, sleep=sleep),
|
|
||||||
lambda: MyTestEnv(size=size + 7, sleep=sleep),
|
|
||||||
]
|
]
|
||||||
venv = [
|
venv = [
|
||||||
VectorEnv(env_fns),
|
VectorEnv(env_fns),
|
||||||
@ -54,6 +49,11 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
|||||||
print(f'VectorEnv: {t[0]:.6f}s')
|
print(f'VectorEnv: {t[0]:.6f}s')
|
||||||
print(f'SubprocVectorEnv: {t[1]:.6f}s')
|
print(f'SubprocVectorEnv: {t[1]:.6f}s')
|
||||||
print(f'RayVectorEnv: {t[2]:.6f}s')
|
print(f'RayVectorEnv: {t[2]:.6f}s')
|
||||||
|
for v in venv:
|
||||||
|
assert v.size == list(range(size, size + num))
|
||||||
|
assert v.env_num == num
|
||||||
|
assert v.action_space == [Discrete(1)] * num
|
||||||
|
|
||||||
for v in venv:
|
for v in venv:
|
||||||
v.close()
|
v.close()
|
||||||
|
|
||||||
|
|||||||
33
tianshou/env/vecenv.py
vendored
33
tianshou/env/vecenv.py
vendored
@ -40,11 +40,30 @@ class BaseVectorEnv(ABC, gym.Env):
|
|||||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||||
self._env_fns = env_fns
|
self._env_fns = env_fns
|
||||||
self.env_num = len(env_fns)
|
self.env_num = len(env_fns)
|
||||||
|
self._obs = None
|
||||||
|
self._rew = None
|
||||||
|
self._done = None
|
||||||
|
self._info = None
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return len(self), which is the number of environments."""
|
"""Return len(self), which is the number of environments."""
|
||||||
return self.env_num
|
return self.env_num
|
||||||
|
|
||||||
|
def __getattribute__(self, key):
|
||||||
|
"""Switch between the default attribute getter or one
|
||||||
|
looking at wrapped environment level depending on the key."""
|
||||||
|
if key not in ('observation_space', 'action_space'):
|
||||||
|
return super().__getattribute__(key)
|
||||||
|
else:
|
||||||
|
return self.__getattr__(key)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __getattr__(self, key):
|
||||||
|
"""Try to retrieve an attribute from each individual wrapped
|
||||||
|
environment, if it does not belong to the wrapping vector
|
||||||
|
environment class."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self, id: Optional[Union[int, List[int]]] = None):
|
def reset(self, id: Optional[Union[int, List[int]]] = None):
|
||||||
"""Reset the state of all the environments and return initial
|
"""Reset the state of all the environments and return initial
|
||||||
@ -118,6 +137,10 @@ class VectorEnv(BaseVectorEnv):
|
|||||||
super().__init__(env_fns)
|
super().__init__(env_fns)
|
||||||
self.envs = [_() for _ in env_fns]
|
self.envs = [_() for _ in env_fns]
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
return [getattr(env, key) if hasattr(env, key) else None
|
||||||
|
for env in self.envs]
|
||||||
|
|
||||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
|
||||||
if id is None:
|
if id is None:
|
||||||
self._obs = np.stack([e.reset() for e in self.envs])
|
self._obs = np.stack([e.reset() for e in self.envs])
|
||||||
@ -179,6 +202,8 @@ def worker(parent, p, env_fn_wrapper):
|
|||||||
p.send(env.render(**data) if hasattr(env, 'render') else None)
|
p.send(env.render(**data) if hasattr(env, 'render') else None)
|
||||||
elif cmd == 'seed':
|
elif cmd == 'seed':
|
||||||
p.send(env.seed(data) if hasattr(env, 'seed') else None)
|
p.send(env.seed(data) if hasattr(env, 'seed') else None)
|
||||||
|
elif cmd == 'getattr':
|
||||||
|
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||||
else:
|
else:
|
||||||
p.close()
|
p.close()
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -211,6 +236,11 @@ class SubprocVectorEnv(BaseVectorEnv):
|
|||||||
for c in self.child_remote:
|
for c in self.child_remote:
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
for p in self.parent_remote:
|
||||||
|
p.send(['getattr', key])
|
||||||
|
return [p.recv() for p in self.parent_remote]
|
||||||
|
|
||||||
def step(self, action: np.ndarray
|
def step(self, action: np.ndarray
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
assert len(action) == self.env_num
|
assert len(action) == self.env_num
|
||||||
@ -289,6 +319,9 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
|
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
|
||||||
for e in env_fns]
|
for e in env_fns]
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
return ray.get([e.getattr.remote(key) for e in self.envs])
|
||||||
|
|
||||||
def step(self, action: np.ndarray
|
def step(self, action: np.ndarray
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
assert len(action) == self.env_num
|
assert len(action) == self.env_num
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user