Fix type check in atari wrapper, solves #1111
This commit is contained in:
parent
60d1ba1c8f
commit
049907d9ab
@ -83,7 +83,7 @@ def get_args() -> argparse.Namespace:
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
def main(args: argparse.Namespace = get_args()) -> None:
|
||||||
env, train_envs, test_envs = make_atari_env(
|
env, train_envs, test_envs = make_atari_env(
|
||||||
args.task,
|
args.task,
|
||||||
args.seed,
|
args.seed,
|
||||||
@ -260,4 +260,4 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_dqn(get_args())
|
main(get_args())
|
||||||
|
@ -39,6 +39,20 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]:
|
|||||||
return reset_result, {}, contains_info
|
return reset_result, {}, contains_info
|
||||||
|
|
||||||
|
|
||||||
|
def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]:
|
||||||
|
obs_space_dtype: type[np.integer] | type[np.floating]
|
||||||
|
if np.issubdtype(obs_space.dtype, np.integer):
|
||||||
|
obs_space_dtype = np.integer
|
||||||
|
elif np.issubdtype(obs_space.dtype, np.floating):
|
||||||
|
obs_space_dtype = np.floating
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Unsupported observation space dtype: {obs_space.dtype}. "
|
||||||
|
f"This might be a bug in tianshou or gymnasium, please report it!",
|
||||||
|
)
|
||||||
|
return obs_space_dtype
|
||||||
|
|
||||||
|
|
||||||
class NoopResetEnv(gym.Wrapper):
|
class NoopResetEnv(gym.Wrapper):
|
||||||
"""Sample initial states by taking random number of no-ops on reset.
|
"""Sample initial states by taking random number of no-ops on reset.
|
||||||
|
|
||||||
@ -199,12 +213,8 @@ class WarpFrame(gym.ObservationWrapper):
|
|||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.size = 84
|
self.size = 84
|
||||||
obs_space = env.observation_space
|
obs_space = env.observation_space
|
||||||
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
|
|
||||||
if np.issubdtype(type(obs_space.dtype), np.integer):
|
|
||||||
obs_space_dtype = np.integer
|
|
||||||
elif np.issubdtype(type(obs_space.dtype), np.floating):
|
|
||||||
obs_space_dtype = np.floating
|
|
||||||
assert isinstance(obs_space, gym.spaces.Box)
|
assert isinstance(obs_space, gym.spaces.Box)
|
||||||
|
obs_space_dtype = get_space_dtype(obs_space)
|
||||||
self.observation_space = gym.spaces.Box(
|
self.observation_space = gym.spaces.Box(
|
||||||
low=np.min(obs_space.low),
|
low=np.min(obs_space.low),
|
||||||
high=np.max(obs_space.high),
|
high=np.max(obs_space.high),
|
||||||
@ -273,15 +283,11 @@ class FrameStack(gym.Wrapper):
|
|||||||
obs_space_shape = env.observation_space.shape
|
obs_space_shape = env.observation_space.shape
|
||||||
assert obs_space_shape is not None
|
assert obs_space_shape is not None
|
||||||
shape = (n_frames, *obs_space_shape)
|
shape = (n_frames, *obs_space_shape)
|
||||||
assert isinstance(env.observation_space, gym.spaces.Box)
|
assert isinstance(obs_space, gym.spaces.Box)
|
||||||
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
|
obs_space_dtype = get_space_dtype(obs_space)
|
||||||
if np.issubdtype(type(obs_space.dtype), np.integer):
|
|
||||||
obs_space_dtype = np.integer
|
|
||||||
elif np.issubdtype(type(obs_space.dtype), np.floating):
|
|
||||||
obs_space_dtype = np.floating
|
|
||||||
self.observation_space = gym.spaces.Box(
|
self.observation_space = gym.spaces.Box(
|
||||||
low=np.min(env.observation_space.low),
|
low=np.min(obs_space.low),
|
||||||
high=np.max(env.observation_space.high),
|
high=np.max(obs_space.high),
|
||||||
shape=shape,
|
shape=shape,
|
||||||
dtype=obs_space_dtype,
|
dtype=obs_space_dtype,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user