Implements set_env_attr and get_env_attr for vector environments (#478)
close #473
This commit is contained in:
parent
098d466467
commit
8f19a86966
@ -166,10 +166,25 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
for i, v in enumerate(venv):
|
||||
print(f'{type(v)}: {t[i]:.6f}s')
|
||||
|
||||
def assert_get(v, expected):
|
||||
assert v.get_env_attr("size") == expected
|
||||
assert v.get_env_attr("size", id=0) == [expected[0]]
|
||||
assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]
|
||||
|
||||
for v in venv:
|
||||
assert v.size == list(range(size, size + num))
|
||||
assert_get(v, list(range(size, size + num)))
|
||||
assert v.env_num == num
|
||||
assert v.action_space == [Discrete(2)] * num
|
||||
|
||||
v.set_env_attr("size", 0)
|
||||
assert_get(v, [0] * num)
|
||||
|
||||
v.set_env_attr("size", 1, 0)
|
||||
assert_get(v, [1] + [0] * (num - 1))
|
||||
|
||||
v.set_env_attr("size", 2, [1, 2, 3])
|
||||
assert_get(v, [1] + [2] * 3 + [0] * (num - 4))
|
||||
|
||||
for v in venv:
|
||||
v.close()
|
||||
|
||||
|
@ -10,7 +10,7 @@ from tianshou.data import Batch
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def data():
|
||||
print("Initialising data...")
|
||||
print("Initializing data...")
|
||||
np.random.seed(0)
|
||||
batch_set = [
|
||||
Batch(
|
||||
@ -19,7 +19,7 @@ def data():
|
||||
'b1': (3.14, 3.14),
|
||||
'b2': np.arange(1e3)
|
||||
},
|
||||
c=i
|
||||
c=i,
|
||||
) for i in np.arange(int(1e4))
|
||||
]
|
||||
batch0 = Batch(
|
||||
@ -27,8 +27,8 @@ def data():
|
||||
b=Batch(
|
||||
c=np.ones((1, ), dtype=np.float64),
|
||||
d=torch.ones((3, 3, 3), dtype=torch.float32),
|
||||
e=list(range(3))
|
||||
)
|
||||
e=list(range(3)),
|
||||
),
|
||||
)
|
||||
batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
|
||||
batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
|
||||
@ -39,13 +39,13 @@ def data():
|
||||
indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False)
|
||||
slice_dict = {
|
||||
'obs': [np.arange(20) for _ in np.arange(batch_len // 10)],
|
||||
'reward': np.arange(batch_len // 10)
|
||||
'reward': np.arange(batch_len // 10),
|
||||
}
|
||||
dict_set = [
|
||||
{
|
||||
'obs': np.arange(20),
|
||||
'info': "this is info",
|
||||
'reward': 0
|
||||
'reward': 0,
|
||||
} for _ in np.arange(1e2)
|
||||
]
|
||||
batch4 = Batch(
|
||||
@ -53,11 +53,11 @@ def data():
|
||||
b=Batch(
|
||||
c=np.ones((1, ), dtype=np.float64),
|
||||
d=torch.ones((1000, 1000), dtype=torch.float32),
|
||||
e=np.arange(1000)
|
||||
)
|
||||
e=np.arange(1000),
|
||||
),
|
||||
)
|
||||
|
||||
print("Initialised")
|
||||
print("Initialized")
|
||||
return {
|
||||
'batch_set': batch_set,
|
||||
'batch0': batch0,
|
||||
@ -67,7 +67,7 @@ def data():
|
||||
'indexs': indexs,
|
||||
'dict_set': dict_set,
|
||||
'slice_dict': slice_dict,
|
||||
'batch4': batch4
|
||||
'batch4': batch4,
|
||||
}
|
||||
|
||||
|
||||
@ -106,7 +106,7 @@ def test_set_attr(data):
|
||||
|
||||
def test_numpy_torch_convert(data):
|
||||
"""Test conversion between numpy and torch."""
|
||||
for _ in np.arange(1e5):
|
||||
for _ in np.arange(1e4): # not sure what's wrong in torch==1.10.0
|
||||
data['batch4'].to_torch()
|
||||
data['batch4'].to_numpy()
|
||||
|
||||
|
51
tianshou/env/venvs.py
vendored
51
tianshou/env/venvs.py
vendored
@ -130,17 +130,56 @@ class BaseVectorEnv(gym.Env):
|
||||
if key in [
|
||||
'metadata', 'reward_range', 'spec', 'action_space', 'observation_space'
|
||||
]: # reserved keys in gym.Env
|
||||
return self.__getattr__(key)
|
||||
return self.get_env_attr(key)
|
||||
else:
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __getattr__(self, key: str) -> List[Any]:
|
||||
"""Fetch a list of env attributes.
|
||||
def get_env_attr(
|
||||
self,
|
||||
key: str,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> List[Any]:
|
||||
"""Get an attribute from the underlying environments.
|
||||
|
||||
This function tries to retrieve an attribute from each individual wrapped
|
||||
environment, if it does not belong to the wrapping vector environment class.
|
||||
If id is an int, retrieve the attribute denoted by key from the environment
|
||||
underlying the worker at index id. The result is returned as a list with one
|
||||
element. Otherwise, retrieve the attribute for all workers at indices id and
|
||||
return a list that is ordered correspondingly to id.
|
||||
|
||||
:param str key: The key of the desired attribute.
|
||||
:param id: Indice(s) of the desired worker(s). Default to None for all env_id.
|
||||
|
||||
:return list: The list of environment attributes.
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
|
||||
return [self.workers[j].get_env_attr(key) for j in id]
|
||||
|
||||
def set_env_attr(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> None:
|
||||
"""Set an attribute in the underlying environments.
|
||||
|
||||
If id is an int, set the attribute denoted by key from the environment
|
||||
underlying the worker at index id to value.
|
||||
Otherwise, set the attribute for all workers at indices id.
|
||||
|
||||
:param str key: The key of the desired attribute.
|
||||
:param Any value: The new value of the attribute.
|
||||
:param id: Indice(s) of the desired worker(s). Default to None for all env_id.
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
for j in id:
|
||||
self.workers[j].set_env_attr(key, value)
|
||||
|
||||
def _wrap_id(
|
||||
self,
|
||||
|
8
tianshou/env/worker/base.py
vendored
8
tianshou/env/worker/base.py
vendored
@ -12,10 +12,14 @@ class EnvWorker(ABC):
|
||||
self._env_fn = env_fn
|
||||
self.is_closed = False
|
||||
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||
self.action_space = getattr(self, "action_space") # noqa: B009
|
||||
self.action_space = self.get_env_attr("action_space") # noqa: B009
|
||||
|
||||
@abstractmethod
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
def get_env_attr(self, key: str) -> Any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_env_attr(self, key: str, value: Any) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
5
tianshou/env/worker/dummy.py
vendored
5
tianshou/env/worker/dummy.py
vendored
@ -13,9 +13,12 @@ class DummyEnvWorker(EnvWorker):
|
||||
self.env = env_fn()
|
||||
super().__init__(env_fn)
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
def get_env_attr(self, key: str) -> Any:
|
||||
return getattr(self.env, key)
|
||||
|
||||
def set_env_attr(self, key: str, value: Any) -> None:
|
||||
setattr(self.env, key, value)
|
||||
|
||||
def reset(self) -> Any:
|
||||
return self.env.reset()
|
||||
|
||||
|
18
tianshou/env/worker/ray.py
vendored
18
tianshou/env/worker/ray.py
vendored
@ -11,15 +11,27 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class _SetAttrWrapper(gym.Wrapper):
|
||||
|
||||
def set_env_attr(self, key: str, value: Any) -> None:
|
||||
setattr(self.env, key, value)
|
||||
|
||||
def get_env_attr(self, key: str) -> Any:
|
||||
return getattr(self.env, key)
|
||||
|
||||
|
||||
class RayEnvWorker(EnvWorker):
|
||||
"""Ray worker used in RayVectorEnv."""
|
||||
|
||||
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
||||
self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn())
|
||||
self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn())
|
||||
super().__init__(env_fn)
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
return ray.get(self.env.__getattr__.remote(key))
|
||||
def get_env_attr(self, key: str) -> Any:
|
||||
return ray.get(self.env.get_env_attr.remote(key))
|
||||
|
||||
def set_env_attr(self, key: str, value: Any) -> None:
|
||||
ray.get(self.env.set_env_attr.remote(key, value))
|
||||
|
||||
def reset(self) -> Any:
|
||||
return ray.get(self.env.reset.remote())
|
||||
|
7
tianshou/env/worker/subproc.py
vendored
7
tianshou/env/worker/subproc.py
vendored
@ -107,6 +107,8 @@ def _worker(
|
||||
p.send(env.seed(data) if hasattr(env, "seed") else None)
|
||||
elif cmd == "getattr":
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
elif cmd == "setattr":
|
||||
setattr(env, data["key"], data["value"])
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
@ -140,10 +142,13 @@ class SubprocEnvWorker(EnvWorker):
|
||||
self.child_remote.close()
|
||||
super().__init__(env_fn)
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
def get_env_attr(self, key: str) -> Any:
|
||||
self.parent_remote.send(["getattr", key])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def set_env_attr(self, key: str, value: Any) -> None:
|
||||
self.parent_remote.send(["setattr", {"key": key, "value": value}])
|
||||
|
||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||
|
||||
def decode_obs(
|
||||
|
Loading…
x
Reference in New Issue
Block a user