Fix venv wrapper reset retval error with gym env (#712)

* Fix venv wrapper reset retval error with gym env

* fix lint
This commit is contained in:
Jiayi Weng 2022-07-31 11:00:38 -07:00 committed by GitHub
parent f270e88461
commit 0f59e38b12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 20 deletions

View File

@ -18,7 +18,7 @@ from tianshou.env import (
)
from tianshou.utils import RunningMeanStd
if __name__ == '__main__':
if __name__ == "__main__":
from env import MyTestEnv, NXEnv
else: # pytest
from test.base.env import MyTestEnv, NXEnv
@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
spent_time = time.time()
while current_idx_start < len(action_list):
A, B, C, D = v.step(action=act, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
b = Batch({"obs": A, "rew": B, "done": C, "info": D})
env_ids = b.info.env_id
o.append(b)
current_idx_start += len(act)
@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
for info in infos:
assert recurse_comp(infos[0], info)
if __name__ == '__main__':
if __name__ == "__main__":
t = [0] * len(venv)
for i, e in enumerate(venv):
t[i] = time.time()
@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
e.reset(np.where(done)[0])
t[i] = time.time() - t[i]
for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s')
print(f"{type(v)}: {t[i]:.6f}s")
def assert_get(v, expected):
assert v.get_env_attr("size") == expected
@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
assert isinstance(info[0], dict)
def test_venv_wrapper_gym(num_envs: int = 4):
# Issue 697
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
envs = VectorEnvNormObs(envs)
obs_ref = envs.reset(return_info=False)
obs, info = envs.reset(return_info=True)
assert isinstance(obs_ref, np.ndarray)
assert isinstance(obs, np.ndarray)
assert isinstance(info, list)
assert isinstance(info[0], dict)
assert obs_ref.shape[0] == obs.shape[0] == len(info) == num_envs
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
eps = np.finfo(np.float32).eps.item()
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
@ -309,7 +322,7 @@ def test_gym_wrappers():
# check conversion is working properly for a batch of actions
np.testing.assert_allclose(
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
np.array([original_act] * bsz)
np.array([original_act] * bsz),
)
# convert multidiscrete with different action number per
# dimension to discrete action space
@ -321,7 +334,7 @@ def test_gym_wrappers():
# check conversion is working properly for a batch of actions
np.testing.assert_allclose(
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
np.array([env_m.action_space.nvec - 1] * bsz)
np.array([env_m.action_space.nvec - 1] * bsz),
)
@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
assert v.shape[0] == num_envs
if __name__ == '__main__':
if __name__ == "__main__":
test_venv_norm_obs()
test_venv_wrapper_gym()
test_venv_wrapper_envpool()
test_venv_wrapper_envpool_gym_reset_return_info()
test_env_obs_dtype()
test_vecenv()
test_attr_unwrapped()

View File

@ -135,7 +135,7 @@ class Collector(object):
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(**gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs, info = rval
@ -173,7 +173,7 @@ class Collector(object):
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(global_ids, **gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs_reset, info = rval

View File

@ -41,7 +41,7 @@ class VectorEnvWrapper(BaseVectorEnv):
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
return self.venv.reset(id, **kwargs)
def step(
@ -84,15 +84,15 @@ class VectorEnvNormObs(VectorEnvWrapper):
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
retval = self.venv.reset(id, **kwargs)
reset_returns_info = isinstance(
retval, (tuple, list)
) and len(retval) == 2 and isinstance(retval[1], dict)
if reset_returns_info:
obs, info = retval
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
rval = self.venv.reset(id, **kwargs)
returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs, info = rval
else:
obs = retval
obs = rval
if isinstance(obs, tuple):
raise TypeError(
@ -103,7 +103,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
obs = self._norm_obs(obs)
if reset_returns_info:
if returns_info:
return obs, info
else:
return obs

View File

@ -185,7 +185,7 @@ class BaseVectorEnv(object):
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
"""Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return