From 6f0e6c6963f55c8d035c93388ad4dfbbfc25afeb Mon Sep 17 00:00:00 2001 From: NM512 Date: Sun, 23 Apr 2023 22:52:30 +0900 Subject: [PATCH] applied formatter to envs --- envs/atari.py | 248 +++++++++++++++++++++------------------- envs/dmc.py | 104 +++++++++-------- envs/dmlab.py | 177 +++++++++++++++-------------- envs/wrappers.py | 288 +++++++++++++++++++++++------------------------ 4 files changed, 417 insertions(+), 400 deletions(-) diff --git a/envs/atari.py b/envs/atari.py index 5c3c05c..e99fa20 100644 --- a/envs/atari.py +++ b/envs/atari.py @@ -2,127 +2,145 @@ import numpy as np class Atari: + LOCK = None - LOCK = None + def __init__( + self, + name, + action_repeat=4, + size=(84, 84), + gray=True, + noops=0, + lives="unused", + sticky=True, + actions="all", + length=108000, + resize="opencv", + seed=None, + ): + assert size[0] == size[1] + assert lives in ("unused", "discount", "reset"), lives + assert actions in ("all", "needed"), actions + assert resize in ("opencv", "pillow"), resize + if self.LOCK is None: + import multiprocessing as mp - def __init__( - self, name, action_repeat=4, size=(84, 84), gray=True, noops=0, lives='unused', - sticky=True, actions='all', length=108000, resize='opencv', seed=None): - assert size[0] == size[1] - assert lives in ('unused', 'discount', 'reset'), lives - assert actions in ('all', 'needed'), actions - assert resize in ('opencv', 'pillow'), resize - if self.LOCK is None: - import multiprocessing as mp - mp = mp.get_context('spawn') - self.LOCK = mp.Lock() - self._resize = resize - if self._resize == 'opencv': - import cv2 - self._cv2 = cv2 - if self._resize == 'pillow': - from PIL import Image - self._image = Image - import gym.envs.atari - if name == 'james_bond': - name = 'jamesbond' - self._repeat = action_repeat - self._size = size - self._gray = gray - self._noops = noops - self._lives = lives - self._sticky = sticky - self._length = length - self._random = np.random.RandomState(seed) - with self.LOCK: - self._env = gym.envs.atari.AtariEnv( - game=name, - obs_type='image', - frameskip=1, repeat_action_probability=0.25 if sticky else 0.0, - full_action_space=(actions == 'all')) - assert self._env.unwrapped.get_action_meanings()[0] == 'NOOP' - shape = self._env.observation_space.shape - self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)] - self._ale = self._env.unwrapped.ale - self._last_lives = None - self._done = True - self._step = 0 + mp = mp.get_context("spawn") + self.LOCK = mp.Lock() + self._resize = resize + if self._resize == "opencv": + import cv2 - @property - def action_space(self): - space = self._env.action_space - space.discrete = True - return space + self._cv2 = cv2 + if self._resize == "pillow": + from PIL import Image - def step(self, action): - # if action['reset'] or self._done: - # with self.LOCK: - # self._reset() - # self._done = False - # self._step = 0 - # return self._obs(0.0, is_first=True) - total = 0.0 - dead = False - if len(action.shape) >= 1: - action = np.argmax(action) - for repeat in range(self._repeat): - _, reward, over, info = self._env.step(action) - self._step += 1 - total += reward - if repeat == self._repeat - 2: - self._screen(self._buffer[1]) - if over: - break - if self._lives != 'unused': - current = self._ale.lives() - if current < self._last_lives: - dead = True - self._last_lives = current - break - if not self._repeat: - self._buffer[1][:] = self._buffer[0][:] - self._screen(self._buffer[0]) - self._done = over or (self._length and self._step >= self._length) or dead - return 
self._obs( - total, - is_last=self._done or (dead and self._lives == 'reset'), - is_terminal=dead or over) + self._image = Image + import gym.envs.atari - def reset(self): - self._env.reset() - if self._noops: - for _ in range(self._random.randint(self._noops)): - _, _, dead, _ = self._env.step(0) - if dead: - self._env.reset() - self._last_lives = self._ale.lives() - self._screen(self._buffer[0]) - self._buffer[1].fill(0) + if name == "james_bond": + name = "jamesbond" + self._repeat = action_repeat + self._size = size + self._gray = gray + self._noops = noops + self._lives = lives + self._sticky = sticky + self._length = length + self._random = np.random.RandomState(seed) + with self.LOCK: + self._env = gym.envs.atari.AtariEnv( + game=name, + obs_type="image", + frameskip=1, + repeat_action_probability=0.25 if sticky else 0.0, + full_action_space=(actions == "all"), + ) + assert self._env.unwrapped.get_action_meanings()[0] == "NOOP" + shape = self._env.observation_space.shape + self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)] + self._ale = self._env.unwrapped.ale + self._last_lives = None + self._done = True + self._step = 0 - self._done = False - self._step = 0 - obs, reward, is_terminal, _ = self._obs(0.0, is_first=True) - return obs + @property + def action_space(self): + space = self._env.action_space + space.discrete = True + return space - def _obs(self, reward, is_first=False, is_last=False, is_terminal=False): - np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0]) - image = self._buffer[0] - if image.shape[:2] != self._size: - if self._resize == 'opencv': - image = self._cv2.resize( - image, self._size, interpolation=self._cv2.INTER_AREA) - if self._resize == 'pillow': - image = self._image.fromarray(image) - image = image.resize(self._size, self._image.NEAREST) - image = np.array(image) - if self._gray: - weights = [0.299, 0.587, 1 - (0.299 + 0.587)] - image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype) - image = image[:, :, None] - return {'image':image, 'is_terminal':is_terminal}, reward, is_last, {} + def step(self, action): + # if action['reset'] or self._done: + # with self.LOCK: + # self._reset() + # self._done = False + # self._step = 0 + # return self._obs(0.0, is_first=True) + total = 0.0 + dead = False + if len(action.shape) >= 1: + action = np.argmax(action) + for repeat in range(self._repeat): + _, reward, over, info = self._env.step(action) + self._step += 1 + total += reward + if repeat == self._repeat - 2: + self._screen(self._buffer[1]) + if over: + break + if self._lives != "unused": + current = self._ale.lives() + if current < self._last_lives: + dead = True + self._last_lives = current + break + if not self._repeat: + self._buffer[1][:] = self._buffer[0][:] + self._screen(self._buffer[0]) + self._done = over or (self._length and self._step >= self._length) or dead + return self._obs( + total, + is_last=self._done or (dead and self._lives == "reset"), + is_terminal=dead or over, + ) - def _screen(self, array): - self._ale.getScreenRGB2(array) + def reset(self): + self._env.reset() + if self._noops: + for _ in range(self._random.randint(self._noops)): + _, _, dead, _ = self._env.step(0) + if dead: + self._env.reset() + self._last_lives = self._ale.lives() + self._screen(self._buffer[0]) + self._buffer[1].fill(0) - def close(self): - return self._env.close() + self._done = False + self._step = 0 + obs, reward, is_terminal, _ = self._obs(0.0, is_first=True) + return obs + + def _obs(self, reward, is_first=False, 
is_last=False, is_terminal=False): + np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0]) + image = self._buffer[0] + if image.shape[:2] != self._size: + if self._resize == "opencv": + image = self._cv2.resize( + image, self._size, interpolation=self._cv2.INTER_AREA + ) + if self._resize == "pillow": + image = self._image.fromarray(image) + image = image.resize(self._size, self._image.NEAREST) + image = np.array(image) + if self._gray: + weights = [0.299, 0.587, 1 - (0.299 + 0.587)] + image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype) + image = image[:, :, None] + return {"image": image, "is_terminal": is_terminal}, reward, is_last, {} + + def _screen(self, array): + self._ale.getScreenRGB2(array) + + def close(self): + return self._env.close() diff --git a/envs/dmc.py b/envs/dmc.py index ece7600..efffdf3 100644 --- a/envs/dmc.py +++ b/envs/dmc.py @@ -3,62 +3,60 @@ import numpy as np class DeepMindControl: + def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): + domain, task = name.split("_", 1) + if domain == "cup": # Only domain with multiple words. + domain = "ball_in_cup" + if isinstance(domain, str): + from dm_control import suite - def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): - domain, task = name.split('_', 1) - if domain == 'cup': # Only domain with multiple words. - domain = 'ball_in_cup' - if isinstance(domain, str): - from dm_control import suite - self._env = suite.load(domain, task) - else: - assert task is None - self._env = domain() - self._action_repeat = action_repeat - self._size = size - if camera is None: - camera = dict(quadruped=2).get(domain, 0) - self._camera = camera + self._env = suite.load(domain, task) + else: + assert task is None + self._env = domain() + self._action_repeat = action_repeat + self._size = size + if camera is None: + camera = dict(quadruped=2).get(domain, 0) + self._camera = camera - @property - def observation_space(self): - spaces = {} - for key, value in self._env.observation_spec().items(): - spaces[key] = gym.spaces.Box( - -np.inf, np.inf, value.shape, dtype=np.float32) - spaces['image'] = gym.spaces.Box( - 0, 255, self._size + (3,), dtype=np.uint8) - return gym.spaces.Dict(spaces) + @property + def observation_space(self): + spaces = {} + for key, value in self._env.observation_spec().items(): + spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32) + spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8) + return gym.spaces.Dict(spaces) - @property - def action_space(self): - spec = self._env.action_spec() - return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) + @property + def action_space(self): + spec = self._env.action_spec() + return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) - def step(self, action): - assert np.isfinite(action).all(), action - reward = 0 - for _ in range(self._action_repeat): - time_step = self._env.step(action) - reward += time_step.reward or 0 - if time_step.last(): - break - obs = dict(time_step.observation) - obs['image'] = self.render() - # There is no terminal state in DMC - obs['is_terminal'] = False - done = time_step.last() - info = {'discount': np.array(time_step.discount, np.float32)} - return obs, reward, done, info + def step(self, action): + assert np.isfinite(action).all(), action + reward = 0 + for _ in range(self._action_repeat): + time_step = self._env.step(action) + reward += time_step.reward or 0 + if time_step.last(): + break + obs = 
dict(time_step.observation) + obs["image"] = self.render() + # There is no terminal state in DMC + obs["is_terminal"] = False + done = time_step.last() + info = {"discount": np.array(time_step.discount, np.float32)} + return obs, reward, done, info - def reset(self): - time_step = self._env.reset() - obs = dict(time_step.observation) - obs['image'] = self.render() - obs['is_terminal'] = False - return obs + def reset(self): + time_step = self._env.reset() + obs = dict(time_step.observation) + obs["image"] = self.render() + obs["is_terminal"] = False + return obs - def render(self, *args, **kwargs): - if kwargs.get('mode', 'rgb_array') != 'rgb_array': - raise ValueError("Only render mode 'rgb_array' is supported.") - return self._env.physics.render(*self._size, camera_id=self._camera) + def render(self, *args, **kwargs): + if kwargs.get("mode", "rgb_array") != "rgb_array": + raise ValueError("Only render mode 'rgb_array' is supported.") + return self._env.physics.render(*self._size, camera_id=self._camera) diff --git a/envs/dmlab.py b/envs/dmlab.py index a17e696..9d8c867 100644 --- a/envs/dmlab.py +++ b/envs/dmlab.py @@ -4,98 +4,105 @@ import deepmind_lab class DeepMindLabyrinth(object): + ACTION_SET_DEFAULT = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (0, 0, 0, -1, 0, 0, 0), # Backward + (0, 0, -1, 0, 0, 0, 0), # Strafe Left + (0, 0, 1, 0, 0, 0, 0), # Strafe Right + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward + (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward + (0, 0, 0, 0, 1, 0, 0), # Fire + ) - ACTION_SET_DEFAULT = ( - (0, 0, 0, 1, 0, 0, 0), # Forward - (0, 0, 0, -1, 0, 0, 0), # Backward - (0, 0, -1, 0, 0, 0, 0), # Strafe Left - (0, 0, 1, 0, 0, 0, 0), # Strafe Right - (-20, 0, 0, 0, 0, 0, 0), # Look Left - (20, 0, 0, 0, 0, 0, 0), # Look Right - (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward - (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward - (0, 0, 0, 0, 1, 0, 0), # Fire - ) + ACTION_SET_MEDIUM = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (0, 0, 0, -1, 0, 0, 0), # Backward + (0, 0, -1, 0, 0, 0, 0), # Strafe Left + (0, 0, 1, 0, 0, 0, 0), # Strafe Right + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + (0, 0, 0, 0, 0, 0, 0), # Idle. + ) - ACTION_SET_MEDIUM = ( - (0, 0, 0, 1, 0, 0, 0), # Forward - (0, 0, 0, -1, 0, 0, 0), # Backward - (0, 0, -1, 0, 0, 0, 0), # Strafe Left - (0, 0, 1, 0, 0, 0, 0), # Strafe Right - (-20, 0, 0, 0, 0, 0, 0), # Look Left - (20, 0, 0, 0, 0, 0, 0), # Look Right - (0, 0, 0, 0, 0, 0, 0), # Idle. 
- ) + ACTION_SET_SMALL = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + ) - ACTION_SET_SMALL = ( - (0, 0, 0, 1, 0, 0, 0), # Forward - (-20, 0, 0, 0, 0, 0, 0), # Look Left - (20, 0, 0, 0, 0, 0, 0), # Look Right - ) + def __init__( + self, + level, + mode, + action_repeat=4, + render_size=(64, 64), + action_set=ACTION_SET_DEFAULT, + level_cache=None, + seed=None, + runfiles_path=None, + ): + assert mode in ("train", "test") + if runfiles_path: + print("Setting DMLab runfiles path:", runfiles_path) + deepmind_lab.set_runfiles_path(runfiles_path) + self._config = {} + self._config["width"] = render_size[0] + self._config["height"] = render_size[1] + self._config["logLevel"] = "WARN" + if mode == "test": + self._config["allowHoldOutLevels"] = "true" + self._config["mixerSeed"] = 0x600D5EED + self._action_repeat = action_repeat + self._random = np.random.RandomState(seed) + self._env = deepmind_lab.Lab( + level="contributed/dmlab30/" + level, + observations=["RGB_INTERLEAVED"], + config={k: str(v) for k, v in self._config.items()}, + level_cache=level_cache, + ) + self._action_set = action_set + self._last_image = None + self._done = True - def __init__( - self, level, mode, action_repeat=4, render_size=(64, 64), - action_set=ACTION_SET_DEFAULT, level_cache=None, seed=None, - runfiles_path=None): - assert mode in ('train', 'test') - if runfiles_path: - print('Setting DMLab runfiles path:', runfiles_path) - deepmind_lab.set_runfiles_path(runfiles_path) - self._config = {} - self._config['width'] = render_size[0] - self._config['height'] = render_size[1] - self._config['logLevel'] = 'WARN' - if mode == 'test': - self._config['allowHoldOutLevels'] = 'true' - self._config['mixerSeed'] = 0x600D5EED - self._action_repeat = action_repeat - self._random = np.random.RandomState(seed) - self._env = deepmind_lab.Lab( - level='contributed/dmlab30/'+level, - observations=['RGB_INTERLEAVED'], - config={k: str(v) for k, v in self._config.items()}, - level_cache=level_cache) - self._action_set = action_set - self._last_image = None - self._done = True + @property + def observation_space(self): + shape = (self._config["height"], self._config["width"], 3) + space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) + return gym.spaces.Dict({"image": space}) - @property - def observation_space(self): - shape = (self._config['height'], self._config['width'], 3) - space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) - return gym.spaces.Dict({'image': space}) + @property + def action_space(self): + return gym.spaces.Discrete(len(self._action_set)) - @property - def action_space(self): - return gym.spaces.Discrete(len(self._action_set)) + def reset(self): + self._done = False + self._env.reset(seed=self._random.randint(0, 2**31 - 1)) + obs = self._get_obs() + return obs - def reset(self): - self._done = False - self._env.reset(seed=self._random.randint(0, 2 ** 31 - 1)) - obs = self._get_obs() - return obs + def step(self, action): + raw_action = np.array(self._action_set[action], np.intc) + reward = self._env.step(raw_action, num_steps=self._action_repeat) + self._done = not self._env.is_running() + obs = self._get_obs() + return obs, reward, self._done, {} - def step(self, action): - raw_action = np.array(self._action_set[action], np.intc) - reward = self._env.step(raw_action, num_steps=self._action_repeat) - self._done = not self._env.is_running() - obs = self._get_obs() - return obs, reward, self._done, {} + def 
render(self, *args, **kwargs): + if kwargs.get("mode", "rgb_array") != "rgb_array": + raise ValueError("Only render mode 'rgb_array' is supported.") + del args # Unused + del kwargs # Unused + return self._last_image - def render(self, *args, **kwargs): - if kwargs.get('mode', 'rgb_array') != 'rgb_array': - raise ValueError("Only render mode 'rgb_array' is supported.") - del args # Unused - del kwargs # Unused - return self._last_image + def close(self): + self._env.close() - def close(self): - self._env.close() - - def _get_obs(self): - if self._done: - image = 0 * self._last_image - else: - image = self._env.observations()['RGB_INTERLEAVED'] - self._last_image = image - return {'image': image} + def _get_obs(self): + if self._done: + image = 0 * self._last_image + else: + image = self._env.observations()["RGB_INTERLEAVED"] + self._last_image = image + return {"image": image} diff --git a/envs/wrappers.py b/envs/wrappers.py index 9341d03..177f2d9 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -3,186 +3,180 @@ import numpy as np class CollectDataset: + def __init__(self, env, callbacks=None, precision=32): + self._env = env + self._callbacks = callbacks or () + self._precision = precision + self._episode = None - def __init__(self, env, callbacks=None, precision=32): - self._env = env - self._callbacks = callbacks or () - self._precision = precision - self._episode = None + def __getattr__(self, name): + return getattr(self._env, name) - def __getattr__(self, name): - return getattr(self._env, name) + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = {k: self._convert(v) for k, v in obs.items()} + transition = obs.copy() + if isinstance(action, dict): + transition.update(action) + else: + transition["action"] = action + transition["reward"] = reward + transition["discount"] = info.get("discount", np.array(1 - float(done))) + self._episode.append(transition) + if done: + for key, value in self._episode[1].items(): + if key not in self._episode[0]: + self._episode[0][key] = 0 * value + episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} + episode = {k: self._convert(v) for k, v in episode.items()} + info["episode"] = episode + for callback in self._callbacks: + callback(episode) + return obs, reward, done, info - def step(self, action): - obs, reward, done, info = self._env.step(action) - obs = {k: self._convert(v) for k, v in obs.items()} - transition = obs.copy() - if isinstance(action, dict): - transition.update(action) - else: - transition['action'] = action - transition['reward'] = reward - transition['discount'] = info.get('discount', np.array(1 - float(done))) - self._episode.append(transition) - if done: - for key, value in self._episode[1].items(): - if key not in self._episode[0]: - self._episode[0][key] = 0 * value - episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} - episode = {k: self._convert(v) for k, v in episode.items()} - info['episode'] = episode - for callback in self._callbacks: - callback(episode) - return obs, reward, done, info + def reset(self): + obs = self._env.reset() + transition = obs.copy() + # Missing keys will be filled with a zeroed out version of the first + # transition, because we do not know what action information the agent will + # pass yet. 
+ transition["reward"] = 0.0 + transition["discount"] = 1.0 + self._episode = [transition] + return obs - def reset(self): - obs = self._env.reset() - transition = obs.copy() - # Missing keys will be filled with a zeroed out version of the first - # transition, because we do not know what action information the agent will - # pass yet. - transition['reward'] = 0.0 - transition['discount'] = 1.0 - self._episode = [transition] - return obs - - def _convert(self, value): - value = np.array(value) - if np.issubdtype(value.dtype, np.floating): - dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] - elif np.issubdtype(value.dtype, np.signedinteger): - dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] - elif np.issubdtype(value.dtype, np.uint8): - dtype = np.uint8 - elif np.issubdtype(value.dtype, np.bool): - dtype = np.bool - else: - raise NotImplementedError(value.dtype) - return value.astype(dtype) + def _convert(self, value): + value = np.array(value) + if np.issubdtype(value.dtype, np.floating): + dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] + elif np.issubdtype(value.dtype, np.signedinteger): + dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] + elif np.issubdtype(value.dtype, np.uint8): + dtype = np.uint8 + elif np.issubdtype(value.dtype, np.bool): + dtype = np.bool + else: + raise NotImplementedError(value.dtype) + return value.astype(dtype) class TimeLimit: + def __init__(self, env, duration): + self._env = env + self._duration = duration + self._step = None - def __init__(self, env, duration): - self._env = env - self._duration = duration - self._step = None + def __getattr__(self, name): + return getattr(self._env, name) - def __getattr__(self, name): - return getattr(self._env, name) + def step(self, action): + assert self._step is not None, "Must reset environment." + obs, reward, done, info = self._env.step(action) + self._step += 1 + if self._step >= self._duration: + done = True + if "discount" not in info: + info["discount"] = np.array(1.0).astype(np.float32) + self._step = None + return obs, reward, done, info - def step(self, action): - assert self._step is not None, 'Must reset environment.' 
- obs, reward, done, info = self._env.step(action) - self._step += 1 - if self._step >= self._duration: - done = True - if 'discount' not in info: - info['discount'] = np.array(1.0).astype(np.float32) - self._step = None - return obs, reward, done, info - - def reset(self): - self._step = 0 - return self._env.reset() + def reset(self): + self._step = 0 + return self._env.reset() class NormalizeActions: + def __init__(self, env): + self._env = env + self._mask = np.logical_and( + np.isfinite(env.action_space.low), np.isfinite(env.action_space.high) + ) + self._low = np.where(self._mask, env.action_space.low, -1) + self._high = np.where(self._mask, env.action_space.high, 1) - def __init__(self, env): - self._env = env - self._mask = np.logical_and( - np.isfinite(env.action_space.low), - np.isfinite(env.action_space.high)) - self._low = np.where(self._mask, env.action_space.low, -1) - self._high = np.where(self._mask, env.action_space.high, 1) + def __getattr__(self, name): + return getattr(self._env, name) - def __getattr__(self, name): - return getattr(self._env, name) + @property + def action_space(self): + low = np.where(self._mask, -np.ones_like(self._low), self._low) + high = np.where(self._mask, np.ones_like(self._low), self._high) + return gym.spaces.Box(low, high, dtype=np.float32) - @property - def action_space(self): - low = np.where(self._mask, -np.ones_like(self._low), self._low) - high = np.where(self._mask, np.ones_like(self._low), self._high) - return gym.spaces.Box(low, high, dtype=np.float32) - - def step(self, action): - original = (action + 1) / 2 * (self._high - self._low) + self._low - original = np.where(self._mask, original, action) - return self._env.step(original) + def step(self, action): + original = (action + 1) / 2 * (self._high - self._low) + self._low + original = np.where(self._mask, original, action) + return self._env.step(original) class OneHotAction: + def __init__(self, env): + assert isinstance(env.action_space, gym.spaces.Discrete) + self._env = env + self._random = np.random.RandomState() - def __init__(self, env): - assert isinstance(env.action_space, gym.spaces.Discrete) - self._env = env - self._random = np.random.RandomState() + def __getattr__(self, name): + return getattr(self._env, name) - def __getattr__(self, name): - return getattr(self._env, name) + @property + def action_space(self): + shape = (self._env.action_space.n,) + space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) + space.sample = self._sample_action + space.discrete = True + return space - @property - def action_space(self): - shape = (self._env.action_space.n,) - space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) - space.sample = self._sample_action - space.discrete = True - return space + def step(self, action): + index = np.argmax(action).astype(int) + reference = np.zeros_like(action) + reference[index] = 1 + if not np.allclose(reference, action): + raise ValueError(f"Invalid one-hot action:\n{action}") + return self._env.step(index) - def step(self, action): - index = np.argmax(action).astype(int) - reference = np.zeros_like(action) - reference[index] = 1 - if not np.allclose(reference, action): - raise ValueError(f'Invalid one-hot action:\n{action}') - return self._env.step(index) + def reset(self): + return self._env.reset() - def reset(self): - return self._env.reset() - - def _sample_action(self): - actions = self._env.action_space.n - index = self._random.randint(0, actions) - reference = np.zeros(actions, dtype=np.float32) - 
reference[index] = 1.0 - return reference + def _sample_action(self): + actions = self._env.action_space.n + index = self._random.randint(0, actions) + reference = np.zeros(actions, dtype=np.float32) + reference[index] = 1.0 + return reference class RewardObs: + def __init__(self, env): + self._env = env - def __init__(self, env): - self._env = env + def __getattr__(self, name): + return getattr(self._env, name) - def __getattr__(self, name): - return getattr(self._env, name) + @property + def observation_space(self): + spaces = self._env.observation_space.spaces + assert "reward" not in spaces + spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) + return gym.spaces.Dict(spaces) - @property - def observation_space(self): - spaces = self._env.observation_space.spaces - assert 'reward' not in spaces - spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) - return gym.spaces.Dict(spaces) + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs["reward"] = reward + return obs, reward, done, info - def step(self, action): - obs, reward, done, info = self._env.step(action) - obs['reward'] = reward - return obs, reward, done, info - - def reset(self): - obs = self._env.reset() - obs['reward'] = 0.0 - return obs + def reset(self): + obs = self._env.reset() + obs["reward"] = 0.0 + return obs class SelectAction: + def __init__(self, env, key): + self._env = env + self._key = key - def __init__(self, env, key): - self._env = env - self._key = key + def __getattr__(self, name): + return getattr(self._env, name) - def __getattr__(self, name): - return getattr(self._env, name) - - def step(self, action): - return self._env.step(action[self._key]) + def step(self, action): + return self._env.step(action[self._key])
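
Usage sketch (illustrative only, not part of the diff): the reformatted classes compose in the usual gym style, with the Atari env wrapped by the helpers from envs/wrappers.py. The game name, wrapper order, and time limit below are assumptions made for the example, not values taken from this patch.

    from envs.atari import Atari
    from envs.wrappers import OneHotAction, TimeLimit

    # Sticky-action Pong with 4-step action repeat and 84x84 grayscale frames.
    env = Atari("pong", action_repeat=4, size=(84, 84), gray=True, sticky=True)
    env = OneHotAction(env)               # one-hot Box action space with a .sample() override
    env = TimeLimit(env, duration=27000)  # mark episodes done after 27000 wrapper steps

    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()  # one-hot float32 vector
        obs, reward, done, info = env.step(action)
    env.close()

The wrappers forward unknown attributes to the wrapped env via __getattr__, so action_space and close() resolve through the whole stack even though only the innermost class defines them.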