fix a problem of the atari dqn example (#861)

This commit is contained in:
Zhenjie Zhao 2023-04-30 23:44:27 +08:00 committed by GitHub
parent 7ce62a6ad4
commit f8808d236f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -135,13 +135,15 @@ class EpisodicLifeEnv(gym.Wrapper):
return obs, reward, term, trunc, info
return obs, reward, done, info
def reset(self):
def reset(self, **kwargs):
"""Calls the Gym environment reset, only when lives are exhausted. This
way all states are still reachable even though lives are episodic, and
the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs, info, self._return_info = _parse_reset_result(self.env.reset())
obs, info, self._return_info = _parse_reset_result(
self.env.reset(**kwargs)
)
else:
# no-op step to advance from terminal/lost life state
step_result = self.env.step(0)
@ -165,8 +167,8 @@ class FireResetEnv(gym.Wrapper):
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self):
_, _, return_info = _parse_reset_result(self.env.reset())
def reset(self, **kwargs):
_, _, return_info = _parse_reset_result(self.env.reset(**kwargs))
obs = self.env.step(1)[0]
return (obs, {}) if return_info else obs
@ -247,8 +249,8 @@ class FrameStack(gym.Wrapper):
dtype=env.observation_space.dtype
)
def reset(self):
obs, info, return_info = _parse_reset_result(self.env.reset())
def reset(self, **kwargs):
obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
for _ in range(self.n_frames):
self.frames.append(obs)
return (self._get_ob(), info) if return_info else self._get_ob()