Fix venv wrapper reset retval error with gym env (#712)
* Fix venv wrapper reset retval error with gym env * fix lint
This commit is contained in:
parent
f270e88461
commit
0f59e38b12
@ -18,7 +18,7 @@ from tianshou.env import (
|
|||||||
)
|
)
|
||||||
from tianshou.utils import RunningMeanStd
|
from tianshou.utils import RunningMeanStd
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
from env import MyTestEnv, NXEnv
|
from env import MyTestEnv, NXEnv
|
||||||
else: # pytest
|
else: # pytest
|
||||||
from test.base.env import MyTestEnv, NXEnv
|
from test.base.env import MyTestEnv, NXEnv
|
||||||
@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
|
|||||||
spent_time = time.time()
|
spent_time = time.time()
|
||||||
while current_idx_start < len(action_list):
|
while current_idx_start < len(action_list):
|
||||||
A, B, C, D = v.step(action=act, id=env_ids)
|
A, B, C, D = v.step(action=act, id=env_ids)
|
||||||
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
|
b = Batch({"obs": A, "rew": B, "done": C, "info": D})
|
||||||
env_ids = b.info.env_id
|
env_ids = b.info.env_id
|
||||||
o.append(b)
|
o.append(b)
|
||||||
current_idx_start += len(act)
|
current_idx_start += len(act)
|
||||||
@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
|||||||
for info in infos:
|
for info in infos:
|
||||||
assert recurse_comp(infos[0], info)
|
assert recurse_comp(infos[0], info)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
t = [0] * len(venv)
|
t = [0] * len(venv)
|
||||||
for i, e in enumerate(venv):
|
for i, e in enumerate(venv):
|
||||||
t[i] = time.time()
|
t[i] = time.time()
|
||||||
@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
|||||||
e.reset(np.where(done)[0])
|
e.reset(np.where(done)[0])
|
||||||
t[i] = time.time() - t[i]
|
t[i] = time.time() - t[i]
|
||||||
for i, v in enumerate(venv):
|
for i, v in enumerate(venv):
|
||||||
print(f'{type(v)}: {t[i]:.6f}s')
|
print(f"{type(v)}: {t[i]:.6f}s")
|
||||||
|
|
||||||
def assert_get(v, expected):
|
def assert_get(v, expected):
|
||||||
assert v.get_env_attr("size") == expected
|
assert v.get_env_attr("size") == expected
|
||||||
@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
|
|||||||
assert isinstance(info[0], dict)
|
assert isinstance(info[0], dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_venv_wrapper_gym(num_envs: int = 4):
|
||||||
|
# Issue 697
|
||||||
|
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
|
||||||
|
envs = VectorEnvNormObs(envs)
|
||||||
|
obs_ref = envs.reset(return_info=False)
|
||||||
|
obs, info = envs.reset(return_info=True)
|
||||||
|
assert isinstance(obs_ref, np.ndarray)
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
assert isinstance(info, list)
|
||||||
|
assert isinstance(info[0], dict)
|
||||||
|
assert obs_ref.shape[0] == obs.shape[0] == len(info) == num_envs
|
||||||
|
|
||||||
|
|
||||||
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
|
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
|
||||||
eps = np.finfo(np.float32).eps.item()
|
eps = np.finfo(np.float32).eps.item()
|
||||||
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
|
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
|
||||||
@ -309,7 +322,7 @@ def test_gym_wrappers():
|
|||||||
# check conversion is working properly for a batch of actions
|
# check conversion is working properly for a batch of actions
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
|
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
|
||||||
np.array([original_act] * bsz)
|
np.array([original_act] * bsz),
|
||||||
)
|
)
|
||||||
# convert multidiscrete with different action number per
|
# convert multidiscrete with different action number per
|
||||||
# dimension to discrete action space
|
# dimension to discrete action space
|
||||||
@ -321,7 +334,7 @@ def test_gym_wrappers():
|
|||||||
# check conversion is working properly for a batch of actions
|
# check conversion is working properly for a batch of actions
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
|
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
|
||||||
np.array([env_m.action_space.nvec - 1] * bsz)
|
np.array([env_m.action_space.nvec - 1] * bsz),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
|
|||||||
assert v.shape[0] == num_envs
|
assert v.shape[0] == num_envs
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
test_venv_norm_obs()
|
test_venv_norm_obs()
|
||||||
|
test_venv_wrapper_gym()
|
||||||
test_venv_wrapper_envpool()
|
test_venv_wrapper_envpool()
|
||||||
|
test_venv_wrapper_envpool_gym_reset_return_info()
|
||||||
test_env_obs_dtype()
|
test_env_obs_dtype()
|
||||||
test_vecenv()
|
test_vecenv()
|
||||||
test_attr_unwrapped()
|
test_attr_unwrapped()
|
||||||
|
@ -135,7 +135,7 @@ class Collector(object):
|
|||||||
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
|
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
|
||||||
rval = self.env.reset(**gym_reset_kwargs)
|
rval = self.env.reset(**gym_reset_kwargs)
|
||||||
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
|
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
|
||||||
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
|
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
|
||||||
)
|
)
|
||||||
if returns_info:
|
if returns_info:
|
||||||
obs, info = rval
|
obs, info = rval
|
||||||
@ -173,7 +173,7 @@ class Collector(object):
|
|||||||
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
|
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
|
||||||
rval = self.env.reset(global_ids, **gym_reset_kwargs)
|
rval = self.env.reset(global_ids, **gym_reset_kwargs)
|
||||||
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
|
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
|
||||||
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
|
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
|
||||||
)
|
)
|
||||||
if returns_info:
|
if returns_info:
|
||||||
obs_reset, info = rval
|
obs_reset, info = rval
|
||||||
|
20
tianshou/env/venv_wrappers.py
vendored
20
tianshou/env/venv_wrappers.py
vendored
@ -41,7 +41,7 @@ class VectorEnvWrapper(BaseVectorEnv):
|
|||||||
self,
|
self,
|
||||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
|
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
|
||||||
return self.venv.reset(id, **kwargs)
|
return self.venv.reset(id, **kwargs)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
@ -84,15 +84,15 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
|||||||
self,
|
self,
|
||||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
|
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
|
||||||
retval = self.venv.reset(id, **kwargs)
|
rval = self.venv.reset(id, **kwargs)
|
||||||
reset_returns_info = isinstance(
|
returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
|
||||||
retval, (tuple, list)
|
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
|
||||||
) and len(retval) == 2 and isinstance(retval[1], dict)
|
)
|
||||||
if reset_returns_info:
|
if returns_info:
|
||||||
obs, info = retval
|
obs, info = rval
|
||||||
else:
|
else:
|
||||||
obs = retval
|
obs = rval
|
||||||
|
|
||||||
if isinstance(obs, tuple):
|
if isinstance(obs, tuple):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -103,7 +103,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
|||||||
if self.obs_rms and self.update_obs_rms:
|
if self.obs_rms and self.update_obs_rms:
|
||||||
self.obs_rms.update(obs)
|
self.obs_rms.update(obs)
|
||||||
obs = self._norm_obs(obs)
|
obs = self._norm_obs(obs)
|
||||||
if reset_returns_info:
|
if returns_info:
|
||||||
return obs, info
|
return obs, info
|
||||||
else:
|
else:
|
||||||
return obs
|
return obs
|
||||||
|
2
tianshou/env/venvs.py
vendored
2
tianshou/env/venvs.py
vendored
@ -185,7 +185,7 @@ class BaseVectorEnv(object):
|
|||||||
self,
|
self,
|
||||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
|
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
|
||||||
"""Reset the state of some envs and return initial observations.
|
"""Reset the state of some envs and return initial observations.
|
||||||
|
|
||||||
If id is None, reset the state of all the environments and return
|
If id is None, reset the state of all the environments and return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user