Updated atari wrappers, fixed pre-commit (#781)
This PR addresses #772 (updates Atari wrappers to work with new Gym API) and some additional issues: - Pre-commit was using gitlab for flake8, which as of recently requires authentication -> Replaced with GitHub - Yapf was quietly failing in pre-commit. Changed it such that it fixes formatting in-place - There is an incompatibility between flake8 and yapf where yapf puts binary operators after the line break and flake8 wants it before the break. I added an exception for flake8. - Also require `packaging` in setup.py My changes shouldn't change the behaviour of the wrappers for older versions, but please double check. Idk whether it's just me, but there are always some incompatibilities between yapf and flake8 that need to resolved manually. It might make sense to try black instead.
This commit is contained in:
parent
662af52820
commit
4c3791a459
@ -9,11 +9,11 @@ repos:
|
||||
# pass_filenames: false
|
||||
# args: [--config-file=setup.cfg, tianshou]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.32.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [-r]
|
||||
args: [-r, -i]
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.10.1
|
||||
@ -21,7 +21,7 @@ repos:
|
||||
- id: isort
|
||||
name: isort
|
||||
|
||||
- repo: https://gitlab.com/PyCQA/flake8
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
|
@ -16,6 +16,16 @@ except ImportError:
|
||||
envpool = None
|
||||
|
||||
|
||||
def _parse_reset_result(reset_result):
|
||||
contains_info = (
|
||||
isinstance(reset_result, tuple) and len(reset_result) == 2
|
||||
and isinstance(reset_result[1], dict)
|
||||
)
|
||||
if contains_info:
|
||||
return reset_result[0], reset_result[1], contains_info
|
||||
return reset_result, {}, contains_info
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
No-op is assumed to be action 0.
|
||||
@ -30,16 +40,23 @@ class NoopResetEnv(gym.Wrapper):
|
||||
self.noop_action = 0
|
||||
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
def reset(self, **kwargs):
|
||||
_, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
|
||||
if hasattr(self.unwrapped.np_random, "integers"):
|
||||
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
|
||||
else:
|
||||
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
|
||||
for _ in range(noops):
|
||||
obs, _, done, _ = self.env.step(self.noop_action)
|
||||
step_result = self.env.step(self.noop_action)
|
||||
if len(step_result) == 4:
|
||||
obs, rew, done, info = step_result
|
||||
else:
|
||||
obs, rew, term, trunc, info = step_result
|
||||
done = term or trunc
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
obs, info, _ = _parse_reset_result(self.env.reset())
|
||||
if return_info:
|
||||
return obs, info
|
||||
return obs
|
||||
|
||||
|
||||
@ -59,14 +76,24 @@ class MaxAndSkipEnv(gym.Wrapper):
|
||||
"""Step the environment with the given action. Repeat action, sum
|
||||
reward, and max over last observations.
|
||||
"""
|
||||
obs_list, total_reward, done = [], 0., False
|
||||
obs_list, total_reward = [], 0.
|
||||
new_step_api = False
|
||||
for _ in range(self._skip):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
step_result = self.env.step(action)
|
||||
if len(step_result) == 4:
|
||||
obs, reward, done, info = step_result
|
||||
else:
|
||||
obs, reward, term, trunc, info = step_result
|
||||
done = term or trunc
|
||||
new_step_api = True
|
||||
obs_list.append(obs)
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
max_frame = np.max(obs_list[-2:], axis=0)
|
||||
if new_step_api:
|
||||
return max_frame, total_reward, term, trunc, info
|
||||
|
||||
return max_frame, total_reward, done, info
|
||||
|
||||
|
||||
@ -81,9 +108,18 @@ class EpisodicLifeEnv(gym.Wrapper):
|
||||
super().__init__(env)
|
||||
self.lives = 0
|
||||
self.was_real_done = True
|
||||
self._return_info = False
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
step_result = self.env.step(action)
|
||||
if len(step_result) == 4:
|
||||
obs, reward, done, info = step_result
|
||||
new_step_api = False
|
||||
else:
|
||||
obs, reward, term, trunc, info = step_result
|
||||
done = term or trunc
|
||||
new_step_api = True
|
||||
|
||||
self.was_real_done = done
|
||||
# check current lives, make loss of life terminal, then update lives to
|
||||
# handle bonus lives
|
||||
@ -93,7 +129,10 @@ class EpisodicLifeEnv(gym.Wrapper):
|
||||
# frames, so its important to keep lives > 0, so that we only reset
|
||||
# once the environment is actually done.
|
||||
done = True
|
||||
term = True
|
||||
self.lives = lives
|
||||
if new_step_api:
|
||||
return obs, reward, term, trunc, info
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
@ -102,12 +141,16 @@ class EpisodicLifeEnv(gym.Wrapper):
|
||||
the learner need not know about any of this behind-the-scenes.
|
||||
"""
|
||||
if self.was_real_done:
|
||||
obs = self.env.reset()
|
||||
obs, info, self._return_info = _parse_reset_result(self.env.reset())
|
||||
else:
|
||||
# no-op step to advance from terminal/lost life state
|
||||
obs = self.env.step(0)[0]
|
||||
step_result = self.env.step(0)
|
||||
obs, info = step_result[0], step_result[-1]
|
||||
self.lives = self.env.unwrapped.ale.lives()
|
||||
return obs
|
||||
if self._return_info:
|
||||
return obs, info
|
||||
else:
|
||||
return obs
|
||||
|
||||
|
||||
class FireResetEnv(gym.Wrapper):
|
||||
@ -123,8 +166,9 @@ class FireResetEnv(gym.Wrapper):
|
||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
return self.env.step(1)[0]
|
||||
_, _, return_info = _parse_reset_result(self.env.reset())
|
||||
obs = self.env.step(1)[0]
|
||||
return (obs, {}) if return_info else obs
|
||||
|
||||
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
@ -204,14 +248,22 @@ class FrameStack(gym.Wrapper):
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
obs = self.env.reset()
|
||||
obs, info, return_info = _parse_reset_result(self.env.reset())
|
||||
for _ in range(self.n_frames):
|
||||
self.frames.append(obs)
|
||||
return self._get_ob()
|
||||
return (self._get_ob(), info) if return_info else self._get_ob()
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
step_result = self.env.step(action)
|
||||
if len(step_result) == 4:
|
||||
obs, reward, done, info = step_result
|
||||
new_step_api = False
|
||||
else:
|
||||
obs, reward, term, trunc, info = step_result
|
||||
new_step_api = True
|
||||
self.frames.append(obs)
|
||||
if new_step_api:
|
||||
return self._get_ob(), reward, term, trunc, info
|
||||
return self._get_ob(), reward, done, info
|
||||
|
||||
def _get_ob(self):
|
||||
|
@ -8,7 +8,7 @@ exclude =
|
||||
dist
|
||||
*.egg-info
|
||||
max-line-length = 87
|
||||
ignore = B305,W504,B006,B008,B024
|
||||
ignore = B305,W504,B006,B008,B024,W503
|
||||
|
||||
[yapf]
|
||||
based_on_style = pep8
|
||||
|
Loading…
x
Reference in New Issue
Block a user