diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index f51c7fd..f237c5a 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -83,7 +83,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_dqn(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -260,4 +260,4 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: if __name__ == "__main__": - test_dqn(get_args()) + main(get_args()) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index db1b6b2..a321fb3 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -39,6 +39,20 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]: return reset_result, {}, contains_info +def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]: + obs_space_dtype: type[np.integer] | type[np.floating] + if np.issubdtype(obs_space.dtype, np.integer): + obs_space_dtype = np.integer + elif np.issubdtype(obs_space.dtype, np.floating): + obs_space_dtype = np.floating + else: + raise TypeError( + f"Unsupported observation space dtype: {obs_space.dtype}. " + f"This might be a bug in tianshou or gymnasium, please report it!", + ) + return obs_space_dtype + + class NoopResetEnv(gym.Wrapper): """Sample initial states by taking random number of no-ops on reset. @@ -199,12 +213,8 @@ class WarpFrame(gym.ObservationWrapper): super().__init__(env) self.size = 84 obs_space = env.observation_space - obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]] - if np.issubdtype(type(obs_space.dtype), np.integer): - obs_space_dtype = np.integer - elif np.issubdtype(type(obs_space.dtype), np.floating): - obs_space_dtype = np.floating assert isinstance(obs_space, gym.spaces.Box) + obs_space_dtype = get_space_dtype(obs_space) self.observation_space = gym.spaces.Box( low=np.min(obs_space.low), high=np.max(obs_space.high), @@ -273,15 +283,11 @@ class FrameStack(gym.Wrapper): obs_space_shape = env.observation_space.shape assert obs_space_shape is not None shape = (n_frames, *obs_space_shape) - assert isinstance(env.observation_space, gym.spaces.Box) - obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]] - if np.issubdtype(type(obs_space.dtype), np.integer): - obs_space_dtype = np.integer - elif np.issubdtype(type(obs_space.dtype), np.floating): - obs_space_dtype = np.floating + assert isinstance(obs_space, gym.spaces.Box) + obs_space_dtype = get_space_dtype(obs_space) self.observation_space = gym.spaces.Box( - low=np.min(env.observation_space.low), - high=np.max(env.observation_space.high), + low=np.min(obs_space.low), + high=np.max(obs_space.high), shape=shape, dtype=obs_space_dtype, )