Fix venv wrapper reset retval error with gym env (#712)

* Fix venv wrapper reset retval error with gym env

* fix lint
This commit is contained in:
Jiayi Weng 2022-07-31 11:00:38 -07:00 committed by GitHub
parent f270e88461
commit 0f59e38b12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 20 deletions

View File

@ -18,7 +18,7 @@ from tianshou.env import (
) )
from tianshou.utils import RunningMeanStd from tianshou.utils import RunningMeanStd
if __name__ == '__main__': if __name__ == "__main__":
from env import MyTestEnv, NXEnv from env import MyTestEnv, NXEnv
else: # pytest else: # pytest
from test.base.env import MyTestEnv, NXEnv from test.base.env import MyTestEnv, NXEnv
@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
spent_time = time.time() spent_time = time.time()
while current_idx_start < len(action_list): while current_idx_start < len(action_list):
A, B, C, D = v.step(action=act, id=env_ids) A, B, C, D = v.step(action=act, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D}) b = Batch({"obs": A, "rew": B, "done": C, "info": D})
env_ids = b.info.env_id env_ids = b.info.env_id
o.append(b) o.append(b)
current_idx_start += len(act) current_idx_start += len(act)
@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
for info in infos: for info in infos:
assert recurse_comp(infos[0], info) assert recurse_comp(infos[0], info)
if __name__ == '__main__': if __name__ == "__main__":
t = [0] * len(venv) t = [0] * len(venv)
for i, e in enumerate(venv): for i, e in enumerate(venv):
t[i] = time.time() t[i] = time.time()
@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
e.reset(np.where(done)[0]) e.reset(np.where(done)[0])
t[i] = time.time() - t[i] t[i] = time.time() - t[i]
for i, v in enumerate(venv): for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s') print(f"{type(v)}: {t[i]:.6f}s")
def assert_get(v, expected): def assert_get(v, expected):
assert v.get_env_attr("size") == expected assert v.get_env_attr("size") == expected
@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
assert isinstance(info[0], dict) assert isinstance(info[0], dict)
def test_venv_wrapper_gym(num_envs: int = 4):
# Issue 697
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
envs = VectorEnvNormObs(envs)
obs_ref = envs.reset(return_info=False)
obs, info = envs.reset(return_info=True)
assert isinstance(obs_ref, np.ndarray)
assert isinstance(obs, np.ndarray)
assert isinstance(info, list)
assert isinstance(info[0], dict)
assert obs_ref.shape[0] == obs.shape[0] == len(info) == num_envs
def run_align_norm_obs(raw_env, train_env, test_env, action_list): def run_align_norm_obs(raw_env, train_env, test_env, action_list):
eps = np.finfo(np.float32).eps.item() eps = np.finfo(np.float32).eps.item()
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()] raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
@ -309,7 +322,7 @@ def test_gym_wrappers():
# check conversion is working properly for a batch of actions # check conversion is working properly for a batch of actions
np.testing.assert_allclose( np.testing.assert_allclose(
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)), env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
np.array([original_act] * bsz) np.array([original_act] * bsz),
) )
# convert multidiscrete with different action number per # convert multidiscrete with different action number per
# dimension to discrete action space # dimension to discrete action space
@ -321,7 +334,7 @@ def test_gym_wrappers():
# check conversion is working properly for a batch of actions # check conversion is working properly for a batch of actions
np.testing.assert_allclose( np.testing.assert_allclose(
env_d.action(np.array([env_d.action_space.n - 1] * bsz)), env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
np.array([env_m.action_space.nvec - 1] * bsz) np.array([env_m.action_space.nvec - 1] * bsz),
) )
@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
assert v.shape[0] == num_envs assert v.shape[0] == num_envs
if __name__ == '__main__': if __name__ == "__main__":
test_venv_norm_obs() test_venv_norm_obs()
test_venv_wrapper_gym()
test_venv_wrapper_envpool() test_venv_wrapper_envpool()
test_venv_wrapper_envpool_gym_reset_return_info()
test_env_obs_dtype() test_env_obs_dtype()
test_vecenv() test_vecenv()
test_attr_unwrapped() test_attr_unwrapped()

View File

@ -135,7 +135,7 @@ class Collector(object):
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(**gym_reset_kwargs) rval = self.env.reset(**gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
) )
if returns_info: if returns_info:
obs, info = rval obs, info = rval
@ -173,7 +173,7 @@ class Collector(object):
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(global_ids, **gym_reset_kwargs) rval = self.env.reset(global_ids, **gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
) )
if returns_info: if returns_info:
obs_reset, info = rval obs_reset, info = rval

View File

@ -41,7 +41,7 @@ class VectorEnvWrapper(BaseVectorEnv):
self, self,
id: Optional[Union[int, List[int], np.ndarray]] = None, id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
return self.venv.reset(id, **kwargs) return self.venv.reset(id, **kwargs)
def step( def step(
@ -84,15 +84,15 @@ class VectorEnvNormObs(VectorEnvWrapper):
self, self,
id: Optional[Union[int, List[int], np.ndarray]] = None, id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
retval = self.venv.reset(id, **kwargs) rval = self.venv.reset(id, **kwargs)
reset_returns_info = isinstance( returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
retval, (tuple, list) isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
) and len(retval) == 2 and isinstance(retval[1], dict) )
if reset_returns_info: if returns_info:
obs, info = retval obs, info = rval
else: else:
obs = retval obs = rval
if isinstance(obs, tuple): if isinstance(obs, tuple):
raise TypeError( raise TypeError(
@ -103,7 +103,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
if self.obs_rms and self.update_obs_rms: if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs) self.obs_rms.update(obs)
obs = self._norm_obs(obs) obs = self._norm_obs(obs)
if reset_returns_info: if returns_info:
return obs, info return obs, info
else: else:
return obs return obs

View File

@ -185,7 +185,7 @@ class BaseVectorEnv(object):
self, self,
id: Optional[Union[int, List[int], np.ndarray]] = None, id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
"""Reset the state of some envs and return initial observations. """Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return If id is None, reset the state of all the environments and return