Implements set_env_attr and get_env_attr for vector environments (#478)

close #473
This commit is contained in:
Markus28 2021-11-02 17:08:00 +01:00 committed by GitHub
parent 098d466467
commit 8f19a86966
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 103 additions and 25 deletions

View File

@ -166,10 +166,25 @@ def test_vecenv(size=10, num=8, sleep=0.001):
for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s')
def assert_get(v, expected):
    """Check ``get_env_attr("size")`` for all envs, a single id and an id list.

    :param v: vector environment exposing ``get_env_attr(key, id=None)``.
    :param expected: expected ``size`` value of every environment, in order;
        the id-list check compares against its first three entries.
    """
    # No id: one value per environment.
    assert v.get_env_attr("size") == expected
    # Integer id: the result is still a list, holding exactly one element.
    assert v.get_env_attr("size", id=0) == [expected[0]]
    # List of ids: values come back in the order of the requested ids.
    assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]
for v in venv:
assert v.size == list(range(size, size + num))
assert_get(v, list(range(size, size + num)))
assert v.env_num == num
assert v.action_space == [Discrete(2)] * num
v.set_env_attr("size", 0)
assert_get(v, [0] * num)
v.set_env_attr("size", 1, 0)
assert_get(v, [1] + [0] * (num - 1))
v.set_env_attr("size", 2, [1, 2, 3])
assert_get(v, [1] + [2] * 3 + [0] * (num - 4))
for v in venv:
v.close()

View File

@ -10,7 +10,7 @@ from tianshou.data import Batch
@pytest.fixture(scope="module")
def data():
print("Initialising data...")
print("Initializing data...")
np.random.seed(0)
batch_set = [
Batch(
@ -19,7 +19,7 @@ def data():
'b1': (3.14, 3.14),
'b2': np.arange(1e3)
},
c=i
c=i,
) for i in np.arange(int(1e4))
]
batch0 = Batch(
@ -27,8 +27,8 @@ def data():
b=Batch(
c=np.ones((1, ), dtype=np.float64),
d=torch.ones((3, 3, 3), dtype=torch.float32),
e=list(range(3))
)
e=list(range(3)),
),
)
batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
@ -39,13 +39,13 @@ def data():
indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False)
slice_dict = {
'obs': [np.arange(20) for _ in np.arange(batch_len // 10)],
'reward': np.arange(batch_len // 10)
'reward': np.arange(batch_len // 10),
}
dict_set = [
{
'obs': np.arange(20),
'info': "this is info",
'reward': 0
'reward': 0,
} for _ in np.arange(1e2)
]
batch4 = Batch(
@ -53,11 +53,11 @@ def data():
b=Batch(
c=np.ones((1, ), dtype=np.float64),
d=torch.ones((1000, 1000), dtype=torch.float32),
e=np.arange(1000)
)
e=np.arange(1000),
),
)
print("Initialised")
print("Initialized")
return {
'batch_set': batch_set,
'batch0': batch0,
@ -67,7 +67,7 @@ def data():
'indexs': indexs,
'dict_set': dict_set,
'slice_dict': slice_dict,
'batch4': batch4
'batch4': batch4,
}
@ -106,7 +106,7 @@ def test_set_attr(data):
def test_numpy_torch_convert(data):
"""Test conversion between numpy and torch."""
for _ in np.arange(1e5):
for _ in np.arange(1e4): # not sure what's wrong in torch==1.10.0
data['batch4'].to_torch()
data['batch4'].to_numpy()

51
tianshou/env/venvs.py vendored
View File

@ -130,17 +130,56 @@ class BaseVectorEnv(gym.Env):
if key in [
'metadata', 'reward_range', 'spec', 'action_space', 'observation_space'
]: # reserved keys in gym.Env
return self.__getattr__(key)
return self.get_env_attr(key)
else:
return super().__getattribute__(key)
def __getattr__(self, key: str) -> List[Any]:
"""Fetch a list of env attributes.
def get_env_attr(
    self,
    key: str,
    id: Optional[Union[int, List[int], np.ndarray]] = None
) -> List[Any]:
    """Get an attribute from the underlying environments.

    If id is an int, retrieve the attribute denoted by key from the environment
    underlying the worker at index id. The result is returned as a list with one
    element. Otherwise, retrieve the attribute for all workers at indices id and
    return a list that is ordered correspondingly to id.

    :param str key: The key of the desired attribute.
    :param id: Index or indices of the desired worker(s). Defaults to None,
        which selects every environment.

    :return list: The list of environment attributes.
    """
    self._assert_is_not_closed()
    # Normalise ``id`` into an iterable of worker indices.
    id = self._wrap_id(id)
    if self.is_async:
        # Async mode only: validate the requested ids before touching workers
        # (the exact check lives in ``_assert_id``).
        self._assert_id(id)
    # Delegate to each worker's own accessor rather than ``getattr(worker, key)``
    # so that every worker type answers through its own ``get_env_attr``.
    return [self.workers[j].get_env_attr(key) for j in id]
def set_env_attr(
    self,
    key: str,
    value: Any,
    id: Optional[Union[int, List[int], np.ndarray]] = None
) -> None:
    """Set an attribute in the underlying environments.

    If id is an int, set the attribute denoted by key from the environment
    underlying the worker at index id to value.
    Otherwise, set the attribute for all workers at indices id.

    :param str key: The key of the desired attribute.
    :param Any value: The new value of the attribute.
    :param id: Index or indices of the desired worker(s). Defaults to None,
        which selects every environment.
    """
    self._assert_is_not_closed()
    # Normalise ``id`` into an iterable of worker indices.
    id = self._wrap_id(id)
    if self.is_async:
        # Async mode only: validate the requested ids before touching workers
        # (the exact check lives in ``_assert_id``).
        self._assert_id(id)
    # The same ``value`` object is handed to every selected worker.
    for j in id:
        self.workers[j].set_env_attr(key, value)
def _wrap_id(
self,

View File

@ -12,10 +12,14 @@ class EnvWorker(ABC):
self._env_fn = env_fn
self.is_closed = False
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
self.action_space = getattr(self, "action_space") # noqa: B009
self.action_space = self.get_env_attr("action_space") # noqa: B009
@abstractmethod
def __getattr__(self, key: str) -> Any:
def get_env_attr(self, key: str) -> Any:
pass
@abstractmethod
def set_env_attr(self, key: str, value: Any) -> None:
pass
@abstractmethod

View File

@ -13,9 +13,12 @@ class DummyEnvWorker(EnvWorker):
self.env = env_fn()
super().__init__(env_fn)
def __getattr__(self, key: str) -> Any:
def get_env_attr(self, key: str) -> Any:
    """Return attribute ``key`` of the environment held in ``self.env``.

    :param str key: name of the attribute to read.
    :return: whatever ``getattr`` yields for that name on the environment.
    :raises AttributeError: if the environment has no such attribute.
    """
    return getattr(self.env, key)
def set_env_attr(self, key: str, value: Any) -> None:
    """Set attribute ``key`` to ``value`` on the environment held in ``self.env``.

    :param str key: name of the attribute to write.
    :param Any value: new value for the attribute.
    """
    setattr(self.env, key, value)
def reset(self) -> Any:
return self.env.reset()

View File

@ -11,15 +11,27 @@ except ImportError:
pass
class _SetAttrWrapper(gym.Wrapper):
    """``gym.Wrapper`` exposing attribute access as ordinary methods.

    ``RayEnvWorker`` instantiates this class as a ray remote object and reaches
    the wrapped environment's attributes through ``get_env_attr.remote`` and
    ``set_env_attr.remote``.
    """

    def set_env_attr(self, key: str, value: Any) -> None:
        """Set attribute ``key`` to ``value`` on the wrapped environment."""
        setattr(self.env, key, value)

    def get_env_attr(self, key: str) -> Any:
        """Return attribute ``key`` of the wrapped environment."""
        return getattr(self.env, key)
class RayEnvWorker(EnvWorker):
"""Ray worker used in RayVectorEnv."""
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn())
self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn())
super().__init__(env_fn)
def __getattr__(self, key: str) -> Any:
return ray.get(self.env.__getattr__.remote(key))
def get_env_attr(self, key: str) -> Any:
    """Read attribute ``key`` from the remote environment.

    Invokes ``get_env_attr`` on the remote object in ``self.env`` and resolves
    the returned handle with ``ray.get``.

    :param str key: name of the attribute to read.
    :return: the value produced by the remote ``get_env_attr`` call.
    """
    return ray.get(self.env.get_env_attr.remote(key))
def set_env_attr(self, key: str, value: Any) -> None:
    """Set attribute ``key`` to ``value`` on the remote environment.

    Invokes ``set_env_attr`` on the remote object in ``self.env``; the returned
    handle is resolved with ``ray.get`` and its result discarded.

    :param str key: name of the attribute to write.
    :param Any value: new value for the attribute.
    """
    ray.get(self.env.set_env_attr.remote(key, value))
def reset(self) -> Any:
return ray.get(self.env.reset.remote())

View File

@ -107,6 +107,8 @@ def _worker(
p.send(env.seed(data) if hasattr(env, "seed") else None)
elif cmd == "getattr":
p.send(getattr(env, data) if hasattr(env, data) else None)
elif cmd == "setattr":
setattr(env, data["key"], data["value"])
else:
p.close()
raise NotImplementedError
@ -140,10 +142,13 @@ class SubprocEnvWorker(EnvWorker):
self.child_remote.close()
super().__init__(env_fn)
def __getattr__(self, key: str) -> Any:
def get_env_attr(self, key: str) -> Any:
    """Ask the worker process for attribute ``key`` of its environment.

    Sends a ``["getattr", key]`` command over ``self.parent_remote`` and returns
    the reply. The worker loop answers ``None`` when its environment has no
    attribute of that name, so a missing attribute cannot be told apart from an
    attribute whose value is ``None``.

    :param str key: name of the attribute to read.
    :return: the value received from the worker process.
    """
    self.parent_remote.send(["getattr", key])
    return self.parent_remote.recv()
def set_env_attr(self, key: str, value: Any) -> None:
    """Tell the worker process to set attribute ``key`` to ``value``.

    Sends a ``["setattr", {"key": ..., "value": ...}]`` command over
    ``self.parent_remote``. The worker loop sends nothing back for this
    command, so no reply is read here.

    :param str key: name of the attribute to write.
    :param Any value: new value for the attribute.
    """
    self.parent_remote.send(["setattr", {"key": key, "value": value}])
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
def decode_obs(