Implements set_env_attr and get_env_attr for vector environments (#478)
close #473
This commit is contained in:
parent
098d466467
commit
8f19a86966
@ -166,10 +166,25 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
|||||||
for i, v in enumerate(venv):
|
for i, v in enumerate(venv):
|
||||||
print(f'{type(v)}: {t[i]:.6f}s')
|
print(f'{type(v)}: {t[i]:.6f}s')
|
||||||
|
|
||||||
|
def assert_get(v, expected):
|
||||||
|
assert v.get_env_attr("size") == expected
|
||||||
|
assert v.get_env_attr("size", id=0) == [expected[0]]
|
||||||
|
assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]
|
||||||
|
|
||||||
for v in venv:
|
for v in venv:
|
||||||
assert v.size == list(range(size, size + num))
|
assert_get(v, list(range(size, size + num)))
|
||||||
assert v.env_num == num
|
assert v.env_num == num
|
||||||
assert v.action_space == [Discrete(2)] * num
|
assert v.action_space == [Discrete(2)] * num
|
||||||
|
|
||||||
|
v.set_env_attr("size", 0)
|
||||||
|
assert_get(v, [0] * num)
|
||||||
|
|
||||||
|
v.set_env_attr("size", 1, 0)
|
||||||
|
assert_get(v, [1] + [0] * (num - 1))
|
||||||
|
|
||||||
|
v.set_env_attr("size", 2, [1, 2, 3])
|
||||||
|
assert_get(v, [1] + [2] * 3 + [0] * (num - 4))
|
||||||
|
|
||||||
for v in venv:
|
for v in venv:
|
||||||
v.close()
|
v.close()
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from tianshou.data import Batch
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def data():
|
def data():
|
||||||
print("Initialising data...")
|
print("Initializing data...")
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
batch_set = [
|
batch_set = [
|
||||||
Batch(
|
Batch(
|
||||||
@ -19,7 +19,7 @@ def data():
|
|||||||
'b1': (3.14, 3.14),
|
'b1': (3.14, 3.14),
|
||||||
'b2': np.arange(1e3)
|
'b2': np.arange(1e3)
|
||||||
},
|
},
|
||||||
c=i
|
c=i,
|
||||||
) for i in np.arange(int(1e4))
|
) for i in np.arange(int(1e4))
|
||||||
]
|
]
|
||||||
batch0 = Batch(
|
batch0 = Batch(
|
||||||
@ -27,8 +27,8 @@ def data():
|
|||||||
b=Batch(
|
b=Batch(
|
||||||
c=np.ones((1, ), dtype=np.float64),
|
c=np.ones((1, ), dtype=np.float64),
|
||||||
d=torch.ones((3, 3, 3), dtype=torch.float32),
|
d=torch.ones((3, 3, 3), dtype=torch.float32),
|
||||||
e=list(range(3))
|
e=list(range(3)),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
|
batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
|
||||||
batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
|
batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
|
||||||
@ -39,13 +39,13 @@ def data():
|
|||||||
indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False)
|
indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False)
|
||||||
slice_dict = {
|
slice_dict = {
|
||||||
'obs': [np.arange(20) for _ in np.arange(batch_len // 10)],
|
'obs': [np.arange(20) for _ in np.arange(batch_len // 10)],
|
||||||
'reward': np.arange(batch_len // 10)
|
'reward': np.arange(batch_len // 10),
|
||||||
}
|
}
|
||||||
dict_set = [
|
dict_set = [
|
||||||
{
|
{
|
||||||
'obs': np.arange(20),
|
'obs': np.arange(20),
|
||||||
'info': "this is info",
|
'info': "this is info",
|
||||||
'reward': 0
|
'reward': 0,
|
||||||
} for _ in np.arange(1e2)
|
} for _ in np.arange(1e2)
|
||||||
]
|
]
|
||||||
batch4 = Batch(
|
batch4 = Batch(
|
||||||
@ -53,11 +53,11 @@ def data():
|
|||||||
b=Batch(
|
b=Batch(
|
||||||
c=np.ones((1, ), dtype=np.float64),
|
c=np.ones((1, ), dtype=np.float64),
|
||||||
d=torch.ones((1000, 1000), dtype=torch.float32),
|
d=torch.ones((1000, 1000), dtype=torch.float32),
|
||||||
e=np.arange(1000)
|
e=np.arange(1000),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Initialised")
|
print("Initialized")
|
||||||
return {
|
return {
|
||||||
'batch_set': batch_set,
|
'batch_set': batch_set,
|
||||||
'batch0': batch0,
|
'batch0': batch0,
|
||||||
@ -67,7 +67,7 @@ def data():
|
|||||||
'indexs': indexs,
|
'indexs': indexs,
|
||||||
'dict_set': dict_set,
|
'dict_set': dict_set,
|
||||||
'slice_dict': slice_dict,
|
'slice_dict': slice_dict,
|
||||||
'batch4': batch4
|
'batch4': batch4,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -106,7 +106,7 @@ def test_set_attr(data):
|
|||||||
|
|
||||||
def test_numpy_torch_convert(data):
|
def test_numpy_torch_convert(data):
|
||||||
"""Test conversion between numpy and torch."""
|
"""Test conversion between numpy and torch."""
|
||||||
for _ in np.arange(1e5):
|
for _ in np.arange(1e4): # not sure what's wrong in torch==1.10.0
|
||||||
data['batch4'].to_torch()
|
data['batch4'].to_torch()
|
||||||
data['batch4'].to_numpy()
|
data['batch4'].to_numpy()
|
||||||
|
|
||||||
|
51
tianshou/env/venvs.py
vendored
51
tianshou/env/venvs.py
vendored
@ -130,17 +130,56 @@ class BaseVectorEnv(gym.Env):
|
|||||||
if key in [
|
if key in [
|
||||||
'metadata', 'reward_range', 'spec', 'action_space', 'observation_space'
|
'metadata', 'reward_range', 'spec', 'action_space', 'observation_space'
|
||||||
]: # reserved keys in gym.Env
|
]: # reserved keys in gym.Env
|
||||||
return self.__getattr__(key)
|
return self.get_env_attr(key)
|
||||||
else:
|
else:
|
||||||
return super().__getattribute__(key)
|
return super().__getattribute__(key)
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> List[Any]:
|
def get_env_attr(
|
||||||
"""Fetch a list of env attributes.
|
self,
|
||||||
|
key: str,
|
||||||
|
id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||||
|
) -> List[Any]:
|
||||||
|
"""Get an attribute from the underlying environments.
|
||||||
|
|
||||||
This function tries to retrieve an attribute from each individual wrapped
|
If id is an int, retrieve the attribute denoted by key from the environment
|
||||||
environment, if it does not belong to the wrapping vector environment class.
|
underlying the worker at index id. The result is returned as a list with one
|
||||||
|
element. Otherwise, retrieve the attribute for all workers at indices id and
|
||||||
|
return a list ordered to match the order of id.
|
||||||
|
|
||||||
|
:param str key: The key of the desired attribute.
|
||||||
|
:param id: Index or indices of the desired worker(s). Defaults to None for all env_id.
|
||||||
|
|
||||||
|
:return list: The list of environment attributes.
|
||||||
"""
|
"""
|
||||||
return [getattr(worker, key) for worker in self.workers]
|
self._assert_is_not_closed()
|
||||||
|
id = self._wrap_id(id)
|
||||||
|
if self.is_async:
|
||||||
|
self._assert_id(id)
|
||||||
|
|
||||||
|
return [self.workers[j].get_env_attr(key) for j in id]
|
||||||
|
|
||||||
|
def set_env_attr(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||||
|
) -> None:
|
||||||
|
"""Set an attribute in the underlying environments.
|
||||||
|
|
||||||
|
If id is an int, set the attribute denoted by key in the environment
|
||||||
|
underlying the worker at index id to value.
|
||||||
|
Otherwise, set the attribute for all workers at indices id.
|
||||||
|
|
||||||
|
:param str key: The key of the desired attribute.
|
||||||
|
:param Any value: The new value of the attribute.
|
||||||
|
:param id: Index or indices of the desired worker(s). Defaults to None for all env_id.
|
||||||
|
"""
|
||||||
|
self._assert_is_not_closed()
|
||||||
|
id = self._wrap_id(id)
|
||||||
|
if self.is_async:
|
||||||
|
self._assert_id(id)
|
||||||
|
for j in id:
|
||||||
|
self.workers[j].set_env_attr(key, value)
|
||||||
|
|
||||||
def _wrap_id(
|
def _wrap_id(
|
||||||
self,
|
self,
|
||||||
|
8
tianshou/env/worker/base.py
vendored
8
tianshou/env/worker/base.py
vendored
@ -12,10 +12,14 @@ class EnvWorker(ABC):
|
|||||||
self._env_fn = env_fn
|
self._env_fn = env_fn
|
||||||
self.is_closed = False
|
self.is_closed = False
|
||||||
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||||
self.action_space = getattr(self, "action_space") # noqa: B009
|
self.action_space = self.get_env_attr("action_space") # noqa: B009
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __getattr__(self, key: str) -> Any:
|
def get_env_attr(self, key: str) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_env_attr(self, key: str, value: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
5
tianshou/env/worker/dummy.py
vendored
5
tianshou/env/worker/dummy.py
vendored
@ -13,9 +13,12 @@ class DummyEnvWorker(EnvWorker):
|
|||||||
self.env = env_fn()
|
self.env = env_fn()
|
||||||
super().__init__(env_fn)
|
super().__init__(env_fn)
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Any:
|
def get_env_attr(self, key: str) -> Any:
|
||||||
return getattr(self.env, key)
|
return getattr(self.env, key)
|
||||||
|
|
||||||
|
def set_env_attr(self, key: str, value: Any) -> None:
|
||||||
|
setattr(self.env, key, value)
|
||||||
|
|
||||||
def reset(self) -> Any:
|
def reset(self) -> Any:
|
||||||
return self.env.reset()
|
return self.env.reset()
|
||||||
|
|
||||||
|
18
tianshou/env/worker/ray.py
vendored
18
tianshou/env/worker/ray.py
vendored
@ -11,15 +11,27 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _SetAttrWrapper(gym.Wrapper):
|
||||||
|
|
||||||
|
def set_env_attr(self, key: str, value: Any) -> None:
|
||||||
|
setattr(self.env, key, value)
|
||||||
|
|
||||||
|
def get_env_attr(self, key: str) -> Any:
|
||||||
|
return getattr(self.env, key)
|
||||||
|
|
||||||
|
|
||||||
class RayEnvWorker(EnvWorker):
|
class RayEnvWorker(EnvWorker):
|
||||||
"""Ray worker used in RayVectorEnv."""
|
"""Ray worker used in RayVectorEnv."""
|
||||||
|
|
||||||
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
||||||
self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn())
|
self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn())
|
||||||
super().__init__(env_fn)
|
super().__init__(env_fn)
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Any:
|
def get_env_attr(self, key: str) -> Any:
|
||||||
return ray.get(self.env.__getattr__.remote(key))
|
return ray.get(self.env.get_env_attr.remote(key))
|
||||||
|
|
||||||
|
def set_env_attr(self, key: str, value: Any) -> None:
|
||||||
|
ray.get(self.env.set_env_attr.remote(key, value))
|
||||||
|
|
||||||
def reset(self) -> Any:
|
def reset(self) -> Any:
|
||||||
return ray.get(self.env.reset.remote())
|
return ray.get(self.env.reset.remote())
|
||||||
|
7
tianshou/env/worker/subproc.py
vendored
7
tianshou/env/worker/subproc.py
vendored
@ -107,6 +107,8 @@ def _worker(
|
|||||||
p.send(env.seed(data) if hasattr(env, "seed") else None)
|
p.send(env.seed(data) if hasattr(env, "seed") else None)
|
||||||
elif cmd == "getattr":
|
elif cmd == "getattr":
|
||||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||||
|
elif cmd == "setattr":
|
||||||
|
setattr(env, data["key"], data["value"])
|
||||||
else:
|
else:
|
||||||
p.close()
|
p.close()
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -140,10 +142,13 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
self.child_remote.close()
|
self.child_remote.close()
|
||||||
super().__init__(env_fn)
|
super().__init__(env_fn)
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Any:
|
def get_env_attr(self, key: str) -> Any:
|
||||||
self.parent_remote.send(["getattr", key])
|
self.parent_remote.send(["getattr", key])
|
||||||
return self.parent_remote.recv()
|
return self.parent_remote.recv()
|
||||||
|
|
||||||
|
def set_env_attr(self, key: str, value: Any) -> None:
|
||||||
|
self.parent_remote.send(["setattr", {"key": key, "value": value}])
|
||||||
|
|
||||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||||
|
|
||||||
def decode_obs(
|
def decode_obs(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user