Updated atari wrappers, fixed pre-commit (#781)

This PR addresses #772 (updating the Atari wrappers to work with the new Gym API)
and a few additional issues:

- Pre-commit was fetching flake8 from GitLab, which recently started requiring
  authentication -> replaced with the GitHub repository
- Yapf was quietly failing in pre-commit. Changed the hook so that it fixes
  formatting in place (`-i`)
- Flake8 and yapf disagree on binary operators in wrapped lines: yapf puts the
  operator after the line break (at the start of the continuation line), while
  flake8's W503 check wants it before the break. Added W503 to flake8's ignore
  list (a short illustration follows below).
- Also require `packaging` in setup.py

My changes shouldn't change the behaviour of the wrappers for older Gym
versions, but please double check.
I don't know whether it's just me, but there are always some incompatibilities
between yapf and flake8 that need to be resolved manually. It might make sense
to try black instead.
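
For reference, a minimal, purely illustrative snippet of the operator-placement conflict (the variable is made up; the layout matches the `_parse_reset_result` helper added in this PR):

```python
# yapf (based_on_style = pep8) wraps long boolean expressions with the binary
# operator at the start of the continuation line:
reset_result = (0, {})  # dummy value so the snippet runs on its own
contains_info = (
    isinstance(reset_result, tuple) and len(reset_result) == 2
    and isinstance(reset_result[1], dict)
)
# flake8 reports the "and" line as W503 ("line break before binary operator"),
# which is why W503 is now added to the ignore list in setup.cfg.
```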
Markus Krimmel, 2022-12-04 22:00:53 +01:00, committed by GitHub
commit 4c3791a459 (parent 662af52820)
4 changed files with 72 additions and 19 deletions


@@ -9,11 +9,11 @@ repos:
   #       pass_filenames: false
   #       args: [--config-file=setup.cfg, tianshou]
-  - repo: https://github.com/pre-commit/mirrors-yapf
+  - repo: https://github.com/google/yapf
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: [-r]
+        args: [-r, -i]
   - repo: https://github.com/pycqa/isort
     rev: 5.10.1
@@ -21,7 +21,7 @@ repos:
       - id: isort
         name: isort
-  - repo: https://gitlab.com/PyCQA/flake8
+  - repo: https://github.com/PyCQA/flake8
     rev: 4.0.1
     hooks:
       - id: flake8


@@ -16,6 +16,16 @@ except ImportError:
     envpool = None
 
 
+def _parse_reset_result(reset_result):
+    contains_info = (
+        isinstance(reset_result, tuple) and len(reset_result) == 2
+        and isinstance(reset_result[1], dict)
+    )
+    if contains_info:
+        return reset_result[0], reset_result[1], contains_info
+    return reset_result, {}, contains_info
+
+
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.
 
     No-op is assumed to be action 0.
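
Not part of the diff — just a quick illustration, assuming the `_parse_reset_result` helper above is in scope, of what it returns under the two reset conventions:

```python
import numpy as np

obs = np.zeros((84, 84), dtype=np.uint8)

# Old Gym API: env.reset() returns only the observation.
parsed, info, has_info = _parse_reset_result(obs)
assert parsed is obs and info == {} and has_info is False

# New Gym API (>= 0.26): env.reset() returns (observation, info_dict).
parsed, info, has_info = _parse_reset_result((obs, {"lives": 5}))
assert parsed is obs and info == {"lives": 5} and has_info is True
```

Note the detection is heuristic: a legacy observation that happened to be a 2-tuple ending in a dict would be treated as the new API.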
@@ -30,16 +40,23 @@ class NoopResetEnv(gym.Wrapper):
         self.noop_action = 0
         assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
 
-    def reset(self):
-        self.env.reset()
+    def reset(self, **kwargs):
+        _, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
         if hasattr(self.unwrapped.np_random, "integers"):
             noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
         else:
             noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
         for _ in range(noops):
-            obs, _, done, _ = self.env.step(self.noop_action)
+            step_result = self.env.step(self.noop_action)
+            if len(step_result) == 4:
+                obs, rew, done, info = step_result
+            else:
+                obs, rew, term, trunc, info = step_result
+                done = term or trunc
             if done:
-                obs = self.env.reset()
+                obs, info, _ = _parse_reset_result(self.env.reset())
+        if return_info:
+            return obs, info
         return obs
@@ -59,14 +76,24 @@ class MaxAndSkipEnv(gym.Wrapper):
         """Step the environment with the given action. Repeat action, sum
         reward, and max over last observations.
         """
-        obs_list, total_reward, done = [], 0., False
+        obs_list, total_reward = [], 0.
+        new_step_api = False
         for _ in range(self._skip):
-            obs, reward, done, info = self.env.step(action)
+            step_result = self.env.step(action)
+            if len(step_result) == 4:
+                obs, reward, done, info = step_result
+            else:
+                obs, reward, term, trunc, info = step_result
+                done = term or trunc
+                new_step_api = True
             obs_list.append(obs)
             total_reward += reward
             if done:
                 break
         max_frame = np.max(obs_list[-2:], axis=0)
+        if new_step_api:
+            return max_frame, total_reward, term, trunc, info
         return max_frame, total_reward, done, info
@@ -81,9 +108,18 @@ class EpisodicLifeEnv(gym.Wrapper):
         super().__init__(env)
         self.lives = 0
         self.was_real_done = True
+        self._return_info = False
 
     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            done = term or trunc
+            new_step_api = True
         self.was_real_done = done
         # check current lives, make loss of life terminal, then update lives to
         # handle bonus lives
@@ -93,7 +129,10 @@ class EpisodicLifeEnv(gym.Wrapper):
             # frames, so its important to keep lives > 0, so that we only reset
             # once the environment is actually done.
             done = True
+            term = True
         self.lives = lives
+        if new_step_api:
+            return obs, reward, term, trunc, info
         return obs, reward, done, info
 
     def reset(self):
@@ -102,12 +141,16 @@ class EpisodicLifeEnv(gym.Wrapper):
         the learner need not know about any of this behind-the-scenes.
         """
         if self.was_real_done:
-            obs = self.env.reset()
+            obs, info, self._return_info = _parse_reset_result(self.env.reset())
         else:
             # no-op step to advance from terminal/lost life state
-            obs = self.env.step(0)[0]
+            step_result = self.env.step(0)
+            obs, info = step_result[0], step_result[-1]
         self.lives = self.env.unwrapped.ale.lives()
-        return obs
+        if self._return_info:
+            return obs, info
+        else:
+            return obs
 
 
 class FireResetEnv(gym.Wrapper):
@@ -123,8 +166,9 @@ class FireResetEnv(gym.Wrapper):
         assert len(env.unwrapped.get_action_meanings()) >= 3
 
     def reset(self):
-        self.env.reset()
-        return self.env.step(1)[0]
+        _, _, return_info = _parse_reset_result(self.env.reset())
+        obs = self.env.step(1)[0]
+        return (obs, {}) if return_info else obs
 
 
 class WarpFrame(gym.ObservationWrapper):
@@ -204,14 +248,22 @@ class FrameStack(gym.Wrapper):
         )
 
     def reset(self):
-        obs = self.env.reset()
+        obs, info, return_info = _parse_reset_result(self.env.reset())
         for _ in range(self.n_frames):
             self.frames.append(obs)
-        return self._get_ob()
+        return (self._get_ob(), info) if return_info else self._get_ob()
 
     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            new_step_api = True
         self.frames.append(obs)
+        if new_step_api:
+            return self._get_ob(), reward, term, trunc, info
         return self._get_ob(), reward, done, info
 
     def _get_ob(self):
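
Not part of the diff: a rough, hypothetical sanity check of the dual-API handling above. It assumes `MaxAndSkipEnv` from the file above is in scope; `OldStyleEnv` and `NewStyleEnv` are made-up dummies, and exact behaviour may vary slightly across Gym releases:

```python
import gym
import numpy as np


class OldStyleEnv(gym.Env):
    """Dummy env using the pre-0.26 API: reset -> obs, step -> 4-tuple."""
    observation_space = gym.spaces.Box(0, 255, (84, 84), np.uint8)
    action_space = gym.spaces.Discrete(1)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), 1.0, False, {}


class NewStyleEnv(OldStyleEnv):
    """Dummy env using the >=0.26 API: reset -> (obs, info), step -> 5-tuple."""

    def reset(self):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 1.0, False, False, {}


# The wrapper mirrors whatever API the underlying env speaks.
old = MaxAndSkipEnv(OldStyleEnv(), skip=4)
new = MaxAndSkipEnv(NewStyleEnv(), skip=4)
assert len(old.step(0)) == 4  # obs, reward, done, info
assert len(new.step(0)) == 5  # obs, reward, terminated, truncated, info
```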


@@ -8,7 +8,7 @@ exclude =
     dist
     *.egg-info
 max-line-length = 87
-ignore = B305,W504,B006,B008,B024
+ignore = B305,W504,B006,B008,B024,W503
 
 [yapf]
 based_on_style = pep8


@@ -23,6 +23,7 @@ def get_install_requires() -> str:
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
         "protobuf~=3.19.0",  # breaking change, sphinx fail
+        "packaging",
     ]