Fix type check in atari wrapper, solves #1111
This commit is contained in:
parent
60d1ba1c8f
commit
049907d9ab
@ -83,7 +83,7 @@ def get_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
||||
def main(args: argparse.Namespace = get_args()) -> None:
|
||||
env, train_envs, test_envs = make_atari_env(
|
||||
args.task,
|
||||
args.seed,
|
||||
@ -260,4 +260,4 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dqn(get_args())
|
||||
main(get_args())
|
||||
|
@ -39,6 +39,20 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]:
|
||||
return reset_result, {}, contains_info
|
||||
|
||||
|
||||
def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]:
|
||||
obs_space_dtype: type[np.integer] | type[np.floating]
|
||||
if np.issubdtype(obs_space.dtype, np.integer):
|
||||
obs_space_dtype = np.integer
|
||||
elif np.issubdtype(obs_space.dtype, np.floating):
|
||||
obs_space_dtype = np.floating
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported observation space dtype: {obs_space.dtype}. "
|
||||
f"This might be a bug in tianshou or gymnasium, please report it!",
|
||||
)
|
||||
return obs_space_dtype
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
|
||||
@ -199,12 +213,8 @@ class WarpFrame(gym.ObservationWrapper):
|
||||
super().__init__(env)
|
||||
self.size = 84
|
||||
obs_space = env.observation_space
|
||||
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
|
||||
if np.issubdtype(type(obs_space.dtype), np.integer):
|
||||
obs_space_dtype = np.integer
|
||||
elif np.issubdtype(type(obs_space.dtype), np.floating):
|
||||
obs_space_dtype = np.floating
|
||||
assert isinstance(obs_space, gym.spaces.Box)
|
||||
obs_space_dtype = get_space_dtype(obs_space)
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=np.min(obs_space.low),
|
||||
high=np.max(obs_space.high),
|
||||
@ -273,15 +283,11 @@ class FrameStack(gym.Wrapper):
|
||||
obs_space_shape = env.observation_space.shape
|
||||
assert obs_space_shape is not None
|
||||
shape = (n_frames, *obs_space_shape)
|
||||
assert isinstance(env.observation_space, gym.spaces.Box)
|
||||
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
|
||||
if np.issubdtype(type(obs_space.dtype), np.integer):
|
||||
obs_space_dtype = np.integer
|
||||
elif np.issubdtype(type(obs_space.dtype), np.floating):
|
||||
obs_space_dtype = np.floating
|
||||
assert isinstance(obs_space, gym.spaces.Box)
|
||||
obs_space_dtype = get_space_dtype(obs_space)
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=np.min(env.observation_space.low),
|
||||
high=np.max(env.observation_space.high),
|
||||
low=np.min(obs_space.low),
|
||||
high=np.max(obs_space.high),
|
||||
shape=shape,
|
||||
dtype=obs_space_dtype,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user