From 4c3791a459f8ff909a38c1b008ed8b71d74e1b98 Mon Sep 17 00:00:00 2001 From: Markus Krimmel Date: Sun, 4 Dec 2022 22:00:53 +0100 Subject: [PATCH] Updated atari wrappers, fixed pre-commit (#781) This PR addresses #772 (updates Atari wrappers to work with new Gym API) and some additional issues: - Pre-commit was using gitlab for flake8, which as of recently requires authentication -> Replaced with GitHub - Yapf was quietly failing in pre-commit. Changed it such that it fixes formatting in-place - There is an incompatibility between flake8 and yapf where yapf puts binary operators after the line break and flake8 wants it before the break. I added an exception for flake8. - Also require `packaging` in setup.py My changes shouldn't change the behaviour of the wrappers for older versions, but please double check. Idk whether it's just me, but there are always some incompatibilities between yapf and flake8 that need to resolved manually. It might make sense to try black instead. --- .pre-commit-config.yaml | 6 +-- examples/atari/atari_wrapper.py | 82 +++++++++++++++++++++++++++------ setup.cfg | 2 +- setup.py | 1 + 4 files changed, 72 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4380db4..804de4d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,11 +9,11 @@ repos: # pass_filenames: false # args: [--config-file=setup.cfg, tianshou] - - repo: https://github.com/pre-commit/mirrors-yapf + - repo: https://github.com/google/yapf rev: v0.32.0 hooks: - id: yapf - args: [-r] + args: [-r, -i] - repo: https://github.com/pycqa/isort rev: 5.10.1 @@ -21,7 +21,7 @@ repos: - id: isort name: isort - - repo: https://gitlab.com/PyCQA/flake8 + - repo: https://github.com/PyCQA/flake8 rev: 4.0.1 hooks: - id: flake8 diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 293c1ce..901b166 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -16,6 +16,16 @@ except ImportError: envpool = None +def _parse_reset_result(reset_result): + contains_info = ( + isinstance(reset_result, tuple) and len(reset_result) == 2 + and isinstance(reset_result[1], dict) + ) + if contains_info: + return reset_result[0], reset_result[1], contains_info + return reset_result, {}, contains_info + + class NoopResetEnv(gym.Wrapper): """Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. @@ -30,16 +40,23 @@ class NoopResetEnv(gym.Wrapper): self.noop_action = 0 assert env.unwrapped.get_action_meanings()[0] == 'NOOP' - def reset(self): - self.env.reset() + def reset(self, **kwargs): + _, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) if hasattr(self.unwrapped.np_random, "integers"): noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) else: noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) for _ in range(noops): - obs, _, done, _ = self.env.step(self.noop_action) + step_result = self.env.step(self.noop_action) + if len(step_result) == 4: + obs, rew, done, info = step_result + else: + obs, rew, term, trunc, info = step_result + done = term or trunc if done: - obs = self.env.reset() + obs, info, _ = _parse_reset_result(self.env.reset()) + if return_info: + return obs, info return obs @@ -59,14 +76,24 @@ class MaxAndSkipEnv(gym.Wrapper): """Step the environment with the given action. Repeat action, sum reward, and max over last observations. """ - obs_list, total_reward, done = [], 0., False + obs_list, total_reward = [], 0. + new_step_api = False for _ in range(self._skip): - obs, reward, done, info = self.env.step(action) + step_result = self.env.step(action) + if len(step_result) == 4: + obs, reward, done, info = step_result + else: + obs, reward, term, trunc, info = step_result + done = term or trunc + new_step_api = True obs_list.append(obs) total_reward += reward if done: break max_frame = np.max(obs_list[-2:], axis=0) + if new_step_api: + return max_frame, total_reward, term, trunc, info + return max_frame, total_reward, done, info @@ -81,9 +108,18 @@ class EpisodicLifeEnv(gym.Wrapper): super().__init__(env) self.lives = 0 self.was_real_done = True + self._return_info = False def step(self, action): - obs, reward, done, info = self.env.step(action) + step_result = self.env.step(action) + if len(step_result) == 4: + obs, reward, done, info = step_result + new_step_api = False + else: + obs, reward, term, trunc, info = step_result + done = term or trunc + new_step_api = True + self.was_real_done = done # check current lives, make loss of life terminal, then update lives to # handle bonus lives @@ -93,7 +129,10 @@ class EpisodicLifeEnv(gym.Wrapper): # frames, so its important to keep lives > 0, so that we only reset # once the environment is actually done. done = True + term = True self.lives = lives + if new_step_api: + return obs, reward, term, trunc, info return obs, reward, done, info def reset(self): @@ -102,12 +141,16 @@ class EpisodicLifeEnv(gym.Wrapper): the learner need not know about any of this behind-the-scenes. """ if self.was_real_done: - obs = self.env.reset() + obs, info, self._return_info = _parse_reset_result(self.env.reset()) else: # no-op step to advance from terminal/lost life state - obs = self.env.step(0)[0] + step_result = self.env.step(0) + obs, info = step_result[0], step_result[-1] self.lives = self.env.unwrapped.ale.lives() - return obs + if self._return_info: + return obs, info + else: + return obs class FireResetEnv(gym.Wrapper): @@ -123,8 +166,9 @@ class FireResetEnv(gym.Wrapper): assert len(env.unwrapped.get_action_meanings()) >= 3 def reset(self): - self.env.reset() - return self.env.step(1)[0] + _, _, return_info = _parse_reset_result(self.env.reset()) + obs = self.env.step(1)[0] + return (obs, {}) if return_info else obs class WarpFrame(gym.ObservationWrapper): @@ -204,14 +248,22 @@ class FrameStack(gym.Wrapper): ) def reset(self): - obs = self.env.reset() + obs, info, return_info = _parse_reset_result(self.env.reset()) for _ in range(self.n_frames): self.frames.append(obs) - return self._get_ob() + return (self._get_ob(), info) if return_info else self._get_ob() def step(self, action): - obs, reward, done, info = self.env.step(action) + step_result = self.env.step(action) + if len(step_result) == 4: + obs, reward, done, info = step_result + new_step_api = False + else: + obs, reward, term, trunc, info = step_result + new_step_api = True self.frames.append(obs) + if new_step_api: + return self._get_ob(), reward, term, trunc, info return self._get_ob(), reward, done, info def _get_ob(self): diff --git a/setup.cfg b/setup.cfg index 9631524..2960359 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ exclude = dist *.egg-info max-line-length = 87 -ignore = B305,W504,B006,B008,B024 +ignore = B305,W504,B006,B008,B024,W503 [yapf] based_on_style = pep8 diff --git a/setup.py b/setup.py index cbed99d..3d96d75 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ def get_install_requires() -> str: "numba>=0.51.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements "protobuf~=3.19.0", # breaking change, sphinx fail + "packaging", ]