From f8808d236f53d86583c74f459365c9cc4184256c Mon Sep 17 00:00:00 2001
From: Zhenjie Zhao <34164876+zhaozj89@users.noreply.github.com>
Date: Sun, 30 Apr 2023 23:44:27 +0800
Subject: [PATCH] fix a problem of the atari dqn example (#861)

---
 examples/atari/atari_wrapper.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index fddbfb0..2ec398d 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -135,13 +135,15 @@ class EpisodicLifeEnv(gym.Wrapper):
             return obs, reward, term, trunc, info
         return obs, reward, done, info
 
-    def reset(self):
+    def reset(self, **kwargs):
         """Calls the Gym environment reset, only when lives are exhausted.
         This way all states are still reachable even though lives are episodic, and
         the learner need not know about any of this behind-the-scenes.
         """
         if self.was_real_done:
-            obs, info, self._return_info = _parse_reset_result(self.env.reset())
+            obs, info, self._return_info = _parse_reset_result(
+                self.env.reset(**kwargs)
+            )
         else:
             # no-op step to advance from terminal/lost life state
             step_result = self.env.step(0)
@@ -165,8 +167,8 @@ class FireResetEnv(gym.Wrapper):
         assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
         assert len(env.unwrapped.get_action_meanings()) >= 3
 
-    def reset(self):
-        _, _, return_info = _parse_reset_result(self.env.reset())
+    def reset(self, **kwargs):
+        _, _, return_info = _parse_reset_result(self.env.reset(**kwargs))
         obs = self.env.step(1)[0]
         return (obs, {}) if return_info else obs
 
@@ -247,8 +249,8 @@ class FrameStack(gym.Wrapper):
             dtype=env.observation_space.dtype
         )
 
-    def reset(self):
-        obs, info, return_info = _parse_reset_result(self.env.reset())
+    def reset(self, **kwargs):
+        obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
         for _ in range(self.n_frames):
             self.frames.append(obs)
         return (self._get_ob(), info) if return_info else self._get_ob()
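
Note (not part of the patch): the fix applies one pattern in three wrappers. Each overridden reset() must accept **kwargs and forward them to the wrapped environment, so that arguments such as seed or options reach the underlying env instead of raising a TypeError. The sketch below illustrates that pattern in isolation; the wrapper name KwargsForwardingWrapper and the use of gymnasium's CartPole-v1 are illustrative assumptions, not code from examples/atari/atari_wrapper.py.

import gymnasium as gym


class KwargsForwardingWrapper(gym.Wrapper):
    """Override reset() while still forwarding keyword arguments downstream."""

    def reset(self, **kwargs):
        # Before the patch, the Atari wrappers defined reset(self) and called
        # self.env.reset(), so env.reset(seed=...) on the wrapped env failed
        # with "reset() got an unexpected keyword argument 'seed'".
        return self.env.reset(**kwargs)


env = KwargsForwardingWrapper(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)  # the seed now reaches the base environment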