Fix venv wrapper reset retval error with gym env (#712)
* Fix venv wrapper reset retval error with gym env * fix lint
This commit is contained in:
parent
f270e88461
commit
0f59e38b12
@ -18,7 +18,7 @@ from tianshou.env import (
|
||||
)
|
||||
from tianshou.utils import RunningMeanStd
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
from env import MyTestEnv, NXEnv
|
||||
else: # pytest
|
||||
from test.base.env import MyTestEnv, NXEnv
|
||||
@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
|
||||
spent_time = time.time()
|
||||
while current_idx_start < len(action_list):
|
||||
A, B, C, D = v.step(action=act, id=env_ids)
|
||||
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
|
||||
b = Batch({"obs": A, "rew": B, "done": C, "info": D})
|
||||
env_ids = b.info.env_id
|
||||
o.append(b)
|
||||
current_idx_start += len(act)
|
||||
@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
for info in infos:
|
||||
assert recurse_comp(infos[0], info)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
t = [0] * len(venv)
|
||||
for i, e in enumerate(venv):
|
||||
t[i] = time.time()
|
||||
@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
e.reset(np.where(done)[0])
|
||||
t[i] = time.time() - t[i]
|
||||
for i, v in enumerate(venv):
|
||||
print(f'{type(v)}: {t[i]:.6f}s')
|
||||
print(f"{type(v)}: {t[i]:.6f}s")
|
||||
|
||||
def assert_get(v, expected):
|
||||
assert v.get_env_attr("size") == expected
|
||||
@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
|
||||
assert isinstance(info[0], dict)
|
||||
|
||||
|
||||
def test_venv_wrapper_gym(num_envs: int = 4):
|
||||
# Issue 697
|
||||
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
|
||||
envs = VectorEnvNormObs(envs)
|
||||
obs_ref = envs.reset(return_info=False)
|
||||
obs, info = envs.reset(return_info=True)
|
||||
assert isinstance(obs_ref, np.ndarray)
|
||||
assert isinstance(obs, np.ndarray)
|
||||
assert isinstance(info, list)
|
||||
assert isinstance(info[0], dict)
|
||||
assert obs_ref.shape[0] == obs.shape[0] == len(info) == num_envs
|
||||
|
||||
|
||||
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
|
||||
eps = np.finfo(np.float32).eps.item()
|
||||
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
|
||||
@ -309,7 +322,7 @@ def test_gym_wrappers():
|
||||
# check conversion is working properly for a batch of actions
|
||||
np.testing.assert_allclose(
|
||||
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
|
||||
np.array([original_act] * bsz)
|
||||
np.array([original_act] * bsz),
|
||||
)
|
||||
# convert multidiscrete with different action number per
|
||||
# dimension to discrete action space
|
||||
@ -321,7 +334,7 @@ def test_gym_wrappers():
|
||||
# check conversion is working properly for a batch of actions
|
||||
np.testing.assert_allclose(
|
||||
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
|
||||
np.array([env_m.action_space.nvec - 1] * bsz)
|
||||
np.array([env_m.action_space.nvec - 1] * bsz),
|
||||
)
|
||||
|
||||
|
||||
@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
|
||||
assert v.shape[0] == num_envs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_venv_norm_obs()
|
||||
test_venv_wrapper_gym()
|
||||
test_venv_wrapper_envpool()
|
||||
test_venv_wrapper_envpool_gym_reset_return_info()
|
||||
test_env_obs_dtype()
|
||||
test_vecenv()
|
||||
test_attr_unwrapped()
|
||||
|
@ -135,7 +135,7 @@ class Collector(object):
|
||||
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
|
||||
rval = self.env.reset(**gym_reset_kwargs)
|
||||
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
|
||||
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
|
||||
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
|
||||
)
|
||||
if returns_info:
|
||||
obs, info = rval
|
||||
@ -173,7 +173,7 @@ class Collector(object):
|
||||
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
|
||||
rval = self.env.reset(global_ids, **gym_reset_kwargs)
|
||||
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
|
||||
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
|
||||
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
|
||||
)
|
||||
if returns_info:
|
||||
obs_reset, info = rval
|
||||
|
20
tianshou/env/venv_wrappers.py
vendored
20
tianshou/env/venv_wrappers.py
vendored
@ -41,7 +41,7 @@ class VectorEnvWrapper(BaseVectorEnv):
|
||||
self,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
|
||||
return self.venv.reset(id, **kwargs)
|
||||
|
||||
def step(
|
||||
@ -84,15 +84,15 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
||||
self,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
|
||||
retval = self.venv.reset(id, **kwargs)
|
||||
reset_returns_info = isinstance(
|
||||
retval, (tuple, list)
|
||||
) and len(retval) == 2 and isinstance(retval[1], dict)
|
||||
if reset_returns_info:
|
||||
obs, info = retval
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
|
||||
rval = self.venv.reset(id, **kwargs)
|
||||
returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
|
||||
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
|
||||
)
|
||||
if returns_info:
|
||||
obs, info = rval
|
||||
else:
|
||||
obs = retval
|
||||
obs = rval
|
||||
|
||||
if isinstance(obs, tuple):
|
||||
raise TypeError(
|
||||
@ -103,7 +103,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
||||
if self.obs_rms and self.update_obs_rms:
|
||||
self.obs_rms.update(obs)
|
||||
obs = self._norm_obs(obs)
|
||||
if reset_returns_info:
|
||||
if returns_info:
|
||||
return obs, info
|
||||
else:
|
||||
return obs
|
||||
|
2
tianshou/env/venvs.py
vendored
2
tianshou/env/venvs.py
vendored
@ -185,7 +185,7 @@ class BaseVectorEnv(object):
|
||||
self,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
|
||||
"""Reset the state of some envs and return initial observations.
|
||||
|
||||
If id is None, reset the state of all the environments and return
|
||||
|
Loading…
x
Reference in New Issue
Block a user