Updated atari wrappers, fixed pre-commit (#781)
This PR addresses #772 (updating the Atari wrappers to work with the new Gym API) and a few additional issues:

- Pre-commit was fetching flake8 from GitLab, which recently started requiring authentication. Replaced it with the GitHub mirror.
- yapf was quietly failing in pre-commit. Changed it so that it fixes formatting in place.
- There is an incompatibility between flake8 and yapf: yapf puts binary operators after the line break, while flake8 wants them before the break. Added a flake8 exception for this.
- Also require `packaging` in setup.py.

These changes shouldn't alter the behaviour of the wrappers on older Gym versions, but please double-check. I don't know whether it's just me, but there are always some incompatibilities between yapf and flake8 that need to be resolved manually; it might make sense to try black instead.
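For context, this is the API change the wrappers have to bridge (a minimal sketch; `CartPole-v1` and the 0.26 version cutoff are illustrative assumptions, not taken from this diff):

    import gym

    env = gym.make("CartPole-v1")

    # Old Gym API: reset() returns only the observation,
    # and step() returns a 4-tuple.
    #   obs = env.reset()
    #   obs, reward, done, info = env.step(action)

    # New Gym API (roughly 0.26+): reset() returns (obs, info),
    # and step() returns a 5-tuple with done split into terminated/truncated.
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated

The wrappers below detect which variant the underlying env speaks (by tuple length) and answer in the same dialect.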
This commit is contained in: parent 662af52820, commit 4c3791a459
@@ -9,11 +9,11 @@ repos:
         # pass_filenames: false
         # args: [--config-file=setup.cfg, tianshou]
 
-  - repo: https://github.com/pre-commit/mirrors-yapf
+  - repo: https://github.com/google/yapf
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: [-r]
+        args: [-r, -i]
 
   - repo: https://github.com/pycqa/isort
     rev: 5.10.1
@@ -21,7 +21,7 @@ repos:
       - id: isort
         name: isort
 
-  - repo: https://gitlab.com/PyCQA/flake8
+  - repo: https://github.com/PyCQA/flake8
     rev: 4.0.1
     hooks:
       - id: flake8
@@ -16,6 +16,16 @@ except ImportError:
     envpool = None
 
 
+def _parse_reset_result(reset_result):
+    contains_info = (
+        isinstance(reset_result, tuple) and len(reset_result) == 2
+        and isinstance(reset_result[1], dict)
+    )
+    if contains_info:
+        return reset_result[0], reset_result[1], contains_info
+    return reset_result, {}, contains_info
+
+
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.
 
     No-op is assumed to be action 0.
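A quick sketch of what the new helper normalizes (values are hypothetical):

    import numpy as np

    # Old-style reset result: a bare observation
    _parse_reset_result(np.zeros(4))              # -> (obs, {}, False)

    # New-style reset result: an (obs, info) tuple
    _parse_reset_result((np.zeros(4), {"x": 1}))  # -> (obs, {"x": 1}, True)

Note the heuristic: any 2-tuple whose second element is a dict is treated as (obs, info). That is fine for Atari observations (arrays), though it would misfire on an env whose observation itself is such a tuple.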
@@ -30,16 +40,23 @@ class NoopResetEnv(gym.Wrapper):
         self.noop_action = 0
         assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
 
-    def reset(self):
-        self.env.reset()
+    def reset(self, **kwargs):
+        _, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
         if hasattr(self.unwrapped.np_random, "integers"):
             noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
         else:
             noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
         for _ in range(noops):
-            obs, _, done, _ = self.env.step(self.noop_action)
+            step_result = self.env.step(self.noop_action)
+            if len(step_result) == 4:
+                obs, rew, done, info = step_result
+            else:
+                obs, rew, term, trunc, info = step_result
+                done = term or trunc
             if done:
-                obs = self.env.reset()
+                obs, info, _ = _parse_reset_result(self.env.reset())
+        if return_info:
+            return obs, info
         return obs
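The wrapper now mirrors whichever reset signature the caller gets from the wrapped env, e.g. (env id is an assumption for illustration):

    env = NoopResetEnv(gym.make("PongNoFrameskip-v4"))
    result = env.reset()
    # old Gym: result is the observation after the random no-ops
    # new Gym: result is an (observation, info) tuple, matching env.reset()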
@@ -59,14 +76,24 @@ class MaxAndSkipEnv(gym.Wrapper):
         """Step the environment with the given action. Repeat action, sum
         reward, and max over last observations.
         """
-        obs_list, total_reward, done = [], 0., False
+        obs_list, total_reward = [], 0.
+        new_step_api = False
         for _ in range(self._skip):
-            obs, reward, done, info = self.env.step(action)
+            step_result = self.env.step(action)
+            if len(step_result) == 4:
+                obs, reward, done, info = step_result
+            else:
+                obs, reward, term, trunc, info = step_result
+                done = term or trunc
+                new_step_api = True
             obs_list.append(obs)
             total_reward += reward
             if done:
                 break
         max_frame = np.max(obs_list[-2:], axis=0)
+        if new_step_api:
+            return max_frame, total_reward, term, trunc, info
+
         return max_frame, total_reward, done, info
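The max over the last two frames is a plain element-wise maximum, used to remove flicker from Atari sprites that are only drawn on alternating frames:

    import numpy as np

    f1 = np.array([[0, 5], [3, 0]])
    f2 = np.array([[2, 1], [0, 4]])
    np.max([f1, f2], axis=0)  # array([[2, 5], [3, 4]])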
@@ -81,9 +108,18 @@ class EpisodicLifeEnv(gym.Wrapper):
         super().__init__(env)
         self.lives = 0
         self.was_real_done = True
+        self._return_info = False
 
     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            done = term or trunc
+            new_step_api = True
+
         self.was_real_done = done
         # check current lives, make loss of life terminal, then update lives to
         # handle bonus lives
@@ -93,7 +129,10 @@ class EpisodicLifeEnv(gym.Wrapper):
             # frames, so its important to keep lives > 0, so that we only reset
             # once the environment is actually done.
             done = True
+            term = True
         self.lives = lives
+        if new_step_api:
+            return obs, reward, term, trunc, info
         return obs, reward, done, info
 
     def reset(self):
@@ -102,12 +141,16 @@ class EpisodicLifeEnv(gym.Wrapper):
         the learner need not know about any of this behind-the-scenes.
         """
         if self.was_real_done:
-            obs = self.env.reset()
+            obs, info, self._return_info = _parse_reset_result(self.env.reset())
         else:
             # no-op step to advance from terminal/lost life state
-            obs = self.env.step(0)[0]
+            step_result = self.env.step(0)
+            obs, info = step_result[0], step_result[-1]
         self.lives = self.env.unwrapped.ale.lives()
-        return obs
+        if self._return_info:
+            return obs, info
+        else:
+            return obs
 
 
 class FireResetEnv(gym.Wrapper):
@@ -123,8 +166,9 @@ class FireResetEnv(gym.Wrapper):
         assert len(env.unwrapped.get_action_meanings()) >= 3
 
     def reset(self):
-        self.env.reset()
-        return self.env.step(1)[0]
+        _, _, return_info = _parse_reset_result(self.env.reset())
+        obs = self.env.step(1)[0]
+        return (obs, {}) if return_info else obs
 
 
 class WarpFrame(gym.ObservationWrapper):
@@ -204,14 +248,22 @@ class FrameStack(gym.Wrapper):
         )
 
     def reset(self):
-        obs = self.env.reset()
+        obs, info, return_info = _parse_reset_result(self.env.reset())
         for _ in range(self.n_frames):
             self.frames.append(obs)
-        return self._get_ob()
+        return (self._get_ob(), info) if return_info else self._get_ob()
 
     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            new_step_api = True
         self.frames.append(obs)
+        if new_step_api:
+            return self._get_ob(), reward, term, trunc, info
         return self._get_ob(), reward, done, info
 
     def _get_ob(self):
@@ -8,7 +8,7 @@ exclude =
     dist
     *.egg-info
 max-line-length = 87
-ignore = B305,W504,B006,B008,B024
+ignore = B305,W504,B006,B008,B024,W503
 
 [yapf]
 based_on_style = pep8
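W503 and W504 are mutually exclusive pycodestyle warnings about which side of a binary operator a line break may fall on; yapf produces the break-before style, so W503 joins the ignore list (W504 was already there):

    # W504: line break after the operator
    total = (first_term +
             second_term)

    # W503: line break before the operator (the style yapf emits)
    total = (first_term
             + second_term)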
setup.py (+1):

@@ -23,6 +23,7 @@ def get_install_requires() -> str:
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
         "protobuf~=3.19.0",  # breaking change, sphinx fail
+        "packaging",
     ]
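The new `packaging` requirement is presumably for robust version comparisons; the wrapper diff above branches on tuple length instead, but a version-based check would look like this (a sketch, not taken from this commit):

    from packaging import version
    import gym

    if version.parse(gym.__version__) >= version.parse("0.26.0"):
        obs, info = env.reset()   # new API
    else:
        obs = env.reset()         # old API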