Fix type check in atari wrapper, solves #1111

This commit is contained in:
Michael Panchenko 2024-04-16 10:52:48 +02:00
parent 60d1ba1c8f
commit 049907d9ab
2 changed files with 21 additions and 15 deletions

View File

@ -83,7 +83,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_args()
def test_dqn(args: argparse.Namespace = get_args()) -> None:
def main(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@ -260,4 +260,4 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__":
test_dqn(get_args())
main(get_args())

View File

@ -39,6 +39,20 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]:
return reset_result, {}, contains_info
def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]:
obs_space_dtype: type[np.integer] | type[np.floating]
if np.issubdtype(obs_space.dtype, np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(obs_space.dtype, np.floating):
obs_space_dtype = np.floating
else:
raise TypeError(
f"Unsupported observation space dtype: {obs_space.dtype}. "
f"This might be a bug in tianshou or gymnasium, please report it!",
)
return obs_space_dtype
class NoopResetEnv(gym.Wrapper):
"""Sample initial states by taking random number of no-ops on reset.
@ -199,12 +213,8 @@ class WarpFrame(gym.ObservationWrapper):
super().__init__(env)
self.size = 84
obs_space = env.observation_space
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
if np.issubdtype(type(obs_space.dtype), np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(type(obs_space.dtype), np.floating):
obs_space_dtype = np.floating
assert isinstance(obs_space, gym.spaces.Box)
obs_space_dtype = get_space_dtype(obs_space)
self.observation_space = gym.spaces.Box(
low=np.min(obs_space.low),
high=np.max(obs_space.high),
@ -273,15 +283,11 @@ class FrameStack(gym.Wrapper):
obs_space_shape = env.observation_space.shape
assert obs_space_shape is not None
shape = (n_frames, *obs_space_shape)
assert isinstance(env.observation_space, gym.spaces.Box)
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
if np.issubdtype(type(obs_space.dtype), np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(type(obs_space.dtype), np.floating):
obs_space_dtype = np.floating
assert isinstance(obs_space, gym.spaces.Box)
obs_space_dtype = get_space_dtype(obs_space)
self.observation_space = gym.spaces.Box(
low=np.min(env.observation_space.low),
high=np.max(env.observation_space.high),
low=np.min(obs_space.low),
high=np.max(obs_space.high),
shape=shape,
dtype=obs_space_dtype,
)