Fix the Atari DQN example: let the wrappers' reset() accept and forward keyword arguments (#861)

This commit is contained in:
Zhenjie Zhao 2023-04-30 23:44:27 +08:00 committed by GitHub
parent 7ce62a6ad4
commit f8808d236f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -135,13 +135,15 @@ class EpisodicLifeEnv(gym.Wrapper):
             return obs, reward, term, trunc, info
         return obs, reward, done, info
 
-    def reset(self):
+    def reset(self, **kwargs):
         """Calls the Gym environment reset, only when lives are exhausted. This
         way all states are still reachable even though lives are episodic, and
         the learner need not know about any of this behind-the-scenes.
         """
         if self.was_real_done:
-            obs, info, self._return_info = _parse_reset_result(self.env.reset())
+            obs, info, self._return_info = _parse_reset_result(
+                self.env.reset(**kwargs)
+            )
         else:
             # no-op step to advance from terminal/lost life state
             step_result = self.env.step(0)
@@ -165,8 +167,8 @@ class FireResetEnv(gym.Wrapper):
         assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
         assert len(env.unwrapped.get_action_meanings()) >= 3
 
-    def reset(self):
-        _, _, return_info = _parse_reset_result(self.env.reset())
+    def reset(self, **kwargs):
+        _, _, return_info = _parse_reset_result(self.env.reset(**kwargs))
         obs = self.env.step(1)[0]
         return (obs, {}) if return_info else obs
 
@@ -247,8 +249,8 @@ class FrameStack(gym.Wrapper):
             dtype=env.observation_space.dtype
         )
 
-    def reset(self):
-        obs, info, return_info = _parse_reset_result(self.env.reset())
+    def reset(self, **kwargs):
+        obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
         for _ in range(self.n_frames):
             self.frames.append(obs)
         return (self._get_ob(), info) if return_info else self._get_ob()