fix a problem of the atari dqn example (#861)
This commit is contained in:
parent
7ce62a6ad4
commit
f8808d236f
@ -135,13 +135,15 @@ class EpisodicLifeEnv(gym.Wrapper):
|
|||||||
return obs, reward, term, trunc, info
|
return obs, reward, term, trunc, info
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, **kwargs):
|
||||||
"""Calls the Gym environment reset, only when lives are exhausted. This
|
"""Calls the Gym environment reset, only when lives are exhausted. This
|
||||||
way all states are still reachable even though lives are episodic, and
|
way all states are still reachable even though lives are episodic, and
|
||||||
the learner need not know about any of this behind-the-scenes.
|
the learner need not know about any of this behind-the-scenes.
|
||||||
"""
|
"""
|
||||||
if self.was_real_done:
|
if self.was_real_done:
|
||||||
obs, info, self._return_info = _parse_reset_result(self.env.reset())
|
obs, info, self._return_info = _parse_reset_result(
|
||||||
|
self.env.reset(**kwargs)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# no-op step to advance from terminal/lost life state
|
# no-op step to advance from terminal/lost life state
|
||||||
step_result = self.env.step(0)
|
step_result = self.env.step(0)
|
||||||
@ -165,8 +167,8 @@ class FireResetEnv(gym.Wrapper):
|
|||||||
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
||||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, **kwargs):
|
||||||
_, _, return_info = _parse_reset_result(self.env.reset())
|
_, _, return_info = _parse_reset_result(self.env.reset(**kwargs))
|
||||||
obs = self.env.step(1)[0]
|
obs = self.env.step(1)[0]
|
||||||
return (obs, {}) if return_info else obs
|
return (obs, {}) if return_info else obs
|
||||||
|
|
||||||
@ -247,8 +249,8 @@ class FrameStack(gym.Wrapper):
|
|||||||
dtype=env.observation_space.dtype
|
dtype=env.observation_space.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, **kwargs):
|
||||||
obs, info, return_info = _parse_reset_result(self.env.reset())
|
obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
|
||||||
for _ in range(self.n_frames):
|
for _ in range(self.n_frames):
|
||||||
self.frames.append(obs)
|
self.frames.append(obs)
|
||||||
return (self._get_ob(), info) if return_info else self._get_ob()
|
return (self._get_ob(), info) if return_info else self._get_ob()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user