Updated atari wrappers, fixed pre-commit (#781)

This PR addresses #772 (updates the Atari wrappers to work with the new Gym API; see the sketch at the end of this description)
and fixes some additional issues:

- Pre-commit was pulling flake8 from GitLab, which now requires
  authentication -> replaced it with the GitHub repo
- Yapf was silently failing in pre-commit. Changed the hook so that it now
  fixes formatting in-place (`-i`)
- There is an incompatibility between flake8 and yapf: yapf puts binary
  operators after the line break, while flake8 (W503) wants them before the
  break. I added W503 to the flake8 ignore list (see the snippet after this list).
- Also require `packaging` in setup.py
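
For context, here is a minimal, hypothetical snippet (not taken from this repo) showing the conflict: yapf's pep8 style wraps long boolean expressions with the operator at the start of the continuation line, which flake8 reports as W503 unless that check is ignored.

```python
# Hypothetical illustration of the yapf/flake8 conflict (not code from this PR).
result = ("obs", {"lives": 3})

# yapf's pep8 style wraps a long condition like this, with the binary
# operator at the start of the continuation line; flake8 reports this as
# W503 ("line break before binary operator") unless W503 is ignored.
contains_info = (
    isinstance(result, tuple) and len(result) == 2
    and isinstance(result[1], dict)
)

# The alternative placement, operator before the break, avoids W503 but
# triggers W504 instead, and yapf reformats it back to the version above.
contains_info = (
    isinstance(result, tuple) and len(result) == 2 and
    isinstance(result[1], dict)
)
print(contains_info)
```

With both W503 and W504 in the setup.cfg ignore list, either placement passes flake8, so yapf's output no longer conflicts with it.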

My changes shouldn't alter the behaviour of the wrappers for older Gym
versions, but please double-check.
I don't know whether it's just me, but there are always some incompatibilities
between yapf and flake8 that need to be resolved manually. It might make
sense to try black instead.
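
For reviewers who haven't followed the Gym API change, here is a rough standalone sketch of the two reset/step signatures the updated wrappers now have to handle. It assumes only `gym` is installed; CartPole is used purely to show the return shapes, the wrappers themselves target Atari.

```python
import gym

# Minimal sketch of the old vs. new Gym API that the wrappers account for.
env = gym.make("CartPole-v1")

# Old API: reset() -> obs; new API (gym >= 0.26): reset() -> (obs, info).
reset_result = env.reset()
if isinstance(reset_result, tuple) and len(reset_result) == 2 and isinstance(reset_result[1], dict):
    obs, info = reset_result  # new API
else:
    obs, info = reset_result, {}  # old API

# Old API: step() -> (obs, reward, done, info);
# new API: step() -> (obs, reward, terminated, truncated, info).
step_result = env.step(env.action_space.sample())
if len(step_result) == 4:
    obs, reward, done, info = step_result
else:
    obs, reward, terminated, truncated, info = step_result
    done = terminated or truncated
print(done, reward)
```

The `_parse_reset_result` helper added in this PR performs essentially the first check, returning `(obs, info, contains_info)` so each wrapper can preserve whichever signature the underlying environment uses.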
Markus Krimmel 2022-12-04 22:00:53 +01:00 committed by GitHub
parent 662af52820
commit 4c3791a459
4 changed files with 72 additions and 19 deletions

.pre-commit-config.yaml

@@ -9,11 +9,11 @@ repos:
 # pass_filenames: false
 # args: [--config-file=setup.cfg, tianshou]
-  - repo: https://github.com/pre-commit/mirrors-yapf
+  - repo: https://github.com/google/yapf
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: [-r]
+        args: [-r, -i]
   - repo: https://github.com/pycqa/isort
     rev: 5.10.1
@@ -21,7 +21,7 @@ repos:
       - id: isort
         name: isort
-  - repo: https://gitlab.com/PyCQA/flake8
+  - repo: https://github.com/PyCQA/flake8
     rev: 4.0.1
     hooks:
       - id: flake8

examples/atari/atari_wrapper.py

@@ -16,6 +16,16 @@ except ImportError:
     envpool = None
+def _parse_reset_result(reset_result):
+    contains_info = (
+        isinstance(reset_result, tuple) and len(reset_result) == 2
+        and isinstance(reset_result[1], dict)
+    )
+    if contains_info:
+        return reset_result[0], reset_result[1], contains_info
+    return reset_result, {}, contains_info
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.
     No-op is assumed to be action 0.
@@ -30,16 +40,23 @@ class NoopResetEnv(gym.Wrapper):
         self.noop_action = 0
         assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
-    def reset(self):
-        self.env.reset()
+    def reset(self, **kwargs):
+        _, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
         if hasattr(self.unwrapped.np_random, "integers"):
             noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
         else:
             noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
         for _ in range(noops):
-            obs, _, done, _ = self.env.step(self.noop_action)
+            step_result = self.env.step(self.noop_action)
+            if len(step_result) == 4:
+                obs, rew, done, info = step_result
+            else:
+                obs, rew, term, trunc, info = step_result
+                done = term or trunc
             if done:
-                obs = self.env.reset()
+                obs, info, _ = _parse_reset_result(self.env.reset())
+        if return_info:
+            return obs, info
         return obs
@@ -59,14 +76,24 @@ class MaxAndSkipEnv(gym.Wrapper):
         """Step the environment with the given action. Repeat action, sum
         reward, and max over last observations.
         """
-        obs_list, total_reward, done = [], 0., False
+        obs_list, total_reward = [], 0.
+        new_step_api = False
         for _ in range(self._skip):
-            obs, reward, done, info = self.env.step(action)
+            step_result = self.env.step(action)
+            if len(step_result) == 4:
+                obs, reward, done, info = step_result
+            else:
+                obs, reward, term, trunc, info = step_result
+                done = term or trunc
+                new_step_api = True
             obs_list.append(obs)
             total_reward += reward
             if done:
                 break
         max_frame = np.max(obs_list[-2:], axis=0)
+        if new_step_api:
+            return max_frame, total_reward, term, trunc, info
         return max_frame, total_reward, done, info
@@ -81,9 +108,18 @@ class EpisodicLifeEnv(gym.Wrapper):
         super().__init__(env)
         self.lives = 0
         self.was_real_done = True
+        self._return_info = False
     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            done = term or trunc
+            new_step_api = True
         self.was_real_done = done
         # check current lives, make loss of life terminal, then update lives to
         # handle bonus lives
@@ -93,7 +129,10 @@
             # frames, so its important to keep lives > 0, so that we only reset
             # once the environment is actually done.
             done = True
+            term = True
         self.lives = lives
+        if new_step_api:
+            return obs, reward, term, trunc, info
         return obs, reward, done, info
     def reset(self):
@@ -102,12 +141,16 @@
         the learner need not know about any of this behind-the-scenes.
         """
         if self.was_real_done:
-            obs = self.env.reset()
+            obs, info, self._return_info = _parse_reset_result(self.env.reset())
         else:
             # no-op step to advance from terminal/lost life state
-            obs = self.env.step(0)[0]
+            step_result = self.env.step(0)
+            obs, info = step_result[0], step_result[-1]
         self.lives = self.env.unwrapped.ale.lives()
-        return obs
+        if self._return_info:
+            return obs, info
+        else:
+            return obs
 class FireResetEnv(gym.Wrapper):
@@ -123,8 +166,9 @@
         assert len(env.unwrapped.get_action_meanings()) >= 3
     def reset(self):
-        self.env.reset()
-        return self.env.step(1)[0]
+        _, _, return_info = _parse_reset_result(self.env.reset())
+        obs = self.env.step(1)[0]
+        return (obs, {}) if return_info else obs
 class WarpFrame(gym.ObservationWrapper):
@@ -204,14 +248,22 @@ class FrameStack(gym.Wrapper):
         )
     def reset(self):
-        obs = self.env.reset()
+        obs, info, return_info = _parse_reset_result(self.env.reset())
         for _ in range(self.n_frames):
             self.frames.append(obs)
-        return self._get_ob()
+        return (self._get_ob(), info) if return_info else self._get_ob()
     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            new_step_api = True
         self.frames.append(obs)
+        if new_step_api:
+            return self._get_ob(), reward, term, trunc, info
         return self._get_ob(), reward, done, info
     def _get_ob(self):

setup.cfg

@@ -8,7 +8,7 @@ exclude =
     dist
     *.egg-info
 max-line-length = 87
-ignore = B305,W504,B006,B008,B024
+ignore = B305,W504,B006,B008,B024,W503
 [yapf]
 based_on_style = pep8

setup.py

@@ -23,6 +23,7 @@ def get_install_requires() -> str:
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
         "protobuf~=3.19.0",  # breaking change, sphinx fail
+        "packaging",
     ]