Implements set_env_attr and get_env_attr for vector environments (#478)

close #473
This commit is contained in:
Markus28 2021-11-02 17:08:00 +01:00 committed by GitHub
parent 098d466467
commit 8f19a86966
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 103 additions and 25 deletions

View File

@ -166,10 +166,25 @@ def test_vecenv(size=10, num=8, sleep=0.001):
for i, v in enumerate(venv): for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s') print(f'{type(v)}: {t[i]:.6f}s')
def assert_get(v, expected):
assert v.get_env_attr("size") == expected
assert v.get_env_attr("size", id=0) == [expected[0]]
assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]
for v in venv: for v in venv:
assert v.size == list(range(size, size + num)) assert_get(v, list(range(size, size + num)))
assert v.env_num == num assert v.env_num == num
assert v.action_space == [Discrete(2)] * num assert v.action_space == [Discrete(2)] * num
v.set_env_attr("size", 0)
assert_get(v, [0] * num)
v.set_env_attr("size", 1, 0)
assert_get(v, [1] + [0] * (num - 1))
v.set_env_attr("size", 2, [1, 2, 3])
assert_get(v, [1] + [2] * 3 + [0] * (num - 4))
for v in venv: for v in venv:
v.close() v.close()

View File

@ -10,7 +10,7 @@ from tianshou.data import Batch
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def data(): def data():
print("Initialising data...") print("Initializing data...")
np.random.seed(0) np.random.seed(0)
batch_set = [ batch_set = [
Batch( Batch(
@ -19,7 +19,7 @@ def data():
'b1': (3.14, 3.14), 'b1': (3.14, 3.14),
'b2': np.arange(1e3) 'b2': np.arange(1e3)
}, },
c=i c=i,
) for i in np.arange(int(1e4)) ) for i in np.arange(int(1e4))
] ]
batch0 = Batch( batch0 = Batch(
@ -27,8 +27,8 @@ def data():
b=Batch( b=Batch(
c=np.ones((1, ), dtype=np.float64), c=np.ones((1, ), dtype=np.float64),
d=torch.ones((3, 3, 3), dtype=torch.float32), d=torch.ones((3, 3, 3), dtype=torch.float32),
e=list(range(3)) e=list(range(3)),
) ),
) )
batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
@ -39,13 +39,13 @@ def data():
indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False) indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False)
slice_dict = { slice_dict = {
'obs': [np.arange(20) for _ in np.arange(batch_len // 10)], 'obs': [np.arange(20) for _ in np.arange(batch_len // 10)],
'reward': np.arange(batch_len // 10) 'reward': np.arange(batch_len // 10),
} }
dict_set = [ dict_set = [
{ {
'obs': np.arange(20), 'obs': np.arange(20),
'info': "this is info", 'info': "this is info",
'reward': 0 'reward': 0,
} for _ in np.arange(1e2) } for _ in np.arange(1e2)
] ]
batch4 = Batch( batch4 = Batch(
@ -53,11 +53,11 @@ def data():
b=Batch( b=Batch(
c=np.ones((1, ), dtype=np.float64), c=np.ones((1, ), dtype=np.float64),
d=torch.ones((1000, 1000), dtype=torch.float32), d=torch.ones((1000, 1000), dtype=torch.float32),
e=np.arange(1000) e=np.arange(1000),
) ),
) )
print("Initialised") print("Initialized")
return { return {
'batch_set': batch_set, 'batch_set': batch_set,
'batch0': batch0, 'batch0': batch0,
@ -67,7 +67,7 @@ def data():
'indexs': indexs, 'indexs': indexs,
'dict_set': dict_set, 'dict_set': dict_set,
'slice_dict': slice_dict, 'slice_dict': slice_dict,
'batch4': batch4 'batch4': batch4,
} }
@ -106,7 +106,7 @@ def test_set_attr(data):
def test_numpy_torch_convert(data): def test_numpy_torch_convert(data):
"""Test conversion between numpy and torch.""" """Test conversion between numpy and torch."""
for _ in np.arange(1e5): for _ in np.arange(1e4): # not sure what's wrong in torch==1.10.0
data['batch4'].to_torch() data['batch4'].to_torch()
data['batch4'].to_numpy() data['batch4'].to_numpy()

51
tianshou/env/venvs.py vendored
View File

@ -130,17 +130,56 @@ class BaseVectorEnv(gym.Env):
if key in [ if key in [
'metadata', 'reward_range', 'spec', 'action_space', 'observation_space' 'metadata', 'reward_range', 'spec', 'action_space', 'observation_space'
]: # reserved keys in gym.Env ]: # reserved keys in gym.Env
return self.__getattr__(key) return self.get_env_attr(key)
else: else:
return super().__getattribute__(key) return super().__getattribute__(key)
def __getattr__(self, key: str) -> List[Any]: def get_env_attr(
"""Fetch a list of env attributes. self,
key: str,
id: Optional[Union[int, List[int], np.ndarray]] = None
) -> List[Any]:
"""Get an attribute from the underlying environments.
This function tries to retrieve an attribute from each individual wrapped If id is an int, retrieve the attribute denoted by key from the environment
environment, if it does not belong to the wrapping vector environment class. underlying the worker at index id. The result is returned as a list with one
element. Otherwise, retrieve the attribute for all workers at indices id and
return a list that is ordered correspondingly to id.
:param str key: The key of the desired attribute.
:param id: Indice(s) of the desired worker(s). Default to None for all env_id.
:return list: The list of environment attributes.
""" """
return [getattr(worker, key) for worker in self.workers] self._assert_is_not_closed()
id = self._wrap_id(id)
if self.is_async:
self._assert_id(id)
return [self.workers[j].get_env_attr(key) for j in id]
def set_env_attr(
self,
key: str,
value: Any,
id: Optional[Union[int, List[int], np.ndarray]] = None
) -> None:
"""Set an attribute in the underlying environments.
If id is an int, set the attribute denoted by key from the environment
underlying the worker at index id to value.
Otherwise, set the attribute for all workers at indices id.
:param str key: The key of the desired attribute.
:param Any value: The new value of the attribute.
:param id: Indice(s) of the desired worker(s). Default to None for all env_id.
"""
self._assert_is_not_closed()
id = self._wrap_id(id)
if self.is_async:
self._assert_id(id)
for j in id:
self.workers[j].set_env_attr(key, value)
def _wrap_id( def _wrap_id(
self, self,

View File

@ -12,10 +12,14 @@ class EnvWorker(ABC):
self._env_fn = env_fn self._env_fn = env_fn
self.is_closed = False self.is_closed = False
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
self.action_space = getattr(self, "action_space") # noqa: B009 self.action_space = self.get_env_attr("action_space") # noqa: B009
@abstractmethod @abstractmethod
def __getattr__(self, key: str) -> Any: def get_env_attr(self, key: str) -> Any:
pass
@abstractmethod
def set_env_attr(self, key: str, value: Any) -> None:
pass pass
@abstractmethod @abstractmethod

View File

@ -13,9 +13,12 @@ class DummyEnvWorker(EnvWorker):
self.env = env_fn() self.env = env_fn()
super().__init__(env_fn) super().__init__(env_fn)
def __getattr__(self, key: str) -> Any: def get_env_attr(self, key: str) -> Any:
return getattr(self.env, key) return getattr(self.env, key)
def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env, key, value)
def reset(self) -> Any: def reset(self) -> Any:
return self.env.reset() return self.env.reset()

View File

@ -11,15 +11,27 @@ except ImportError:
pass pass
class _SetAttrWrapper(gym.Wrapper):
def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env, key, value)
def get_env_attr(self, key: str) -> Any:
return getattr(self.env, key)
class RayEnvWorker(EnvWorker): class RayEnvWorker(EnvWorker):
"""Ray worker used in RayVectorEnv.""" """Ray worker used in RayVectorEnv."""
def __init__(self, env_fn: Callable[[], gym.Env]) -> None: def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn()) self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn())
super().__init__(env_fn) super().__init__(env_fn)
def __getattr__(self, key: str) -> Any: def get_env_attr(self, key: str) -> Any:
return ray.get(self.env.__getattr__.remote(key)) return ray.get(self.env.get_env_attr.remote(key))
def set_env_attr(self, key: str, value: Any) -> None:
ray.get(self.env.set_env_attr.remote(key, value))
def reset(self) -> Any: def reset(self) -> Any:
return ray.get(self.env.reset.remote()) return ray.get(self.env.reset.remote())

View File

@ -107,6 +107,8 @@ def _worker(
p.send(env.seed(data) if hasattr(env, "seed") else None) p.send(env.seed(data) if hasattr(env, "seed") else None)
elif cmd == "getattr": elif cmd == "getattr":
p.send(getattr(env, data) if hasattr(env, data) else None) p.send(getattr(env, data) if hasattr(env, data) else None)
elif cmd == "setattr":
setattr(env, data["key"], data["value"])
else: else:
p.close() p.close()
raise NotImplementedError raise NotImplementedError
@ -140,10 +142,13 @@ class SubprocEnvWorker(EnvWorker):
self.child_remote.close() self.child_remote.close()
super().__init__(env_fn) super().__init__(env_fn)
def __getattr__(self, key: str) -> Any: def get_env_attr(self, key: str) -> Any:
self.parent_remote.send(["getattr", key]) self.parent_remote.send(["getattr", key])
return self.parent_remote.recv() return self.parent_remote.recv()
def set_env_attr(self, key: str, value: Any) -> None:
self.parent_remote.send(["setattr", {"key": key, "value": value}])
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]: def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
def decode_obs( def decode_obs(