From e3329b35e5378ad570d96125fff4102af0a35224 Mon Sep 17 00:00:00 2001 From: NM512 Date: Sun, 18 Jun 2023 16:57:05 +0900 Subject: [PATCH] applied formatter --- dreamer.py | 10 ++-- envs/memorymaze.py | 137 ++++++++++++++++++++++----------------------- 2 files changed, 74 insertions(+), 73 deletions(-) diff --git a/dreamer.py b/dreamer.py index fec590f..782d13a 100644 --- a/dreamer.py +++ b/dreamer.py @@ -213,13 +213,15 @@ def make_env(config, logger, mode, train_eps, eval_eps): env = wrappers.OneHotAction(env) elif suite == "MemoryMaze": import gym - if task == '9x9': - env = gym.make('memory_maze:MemoryMaze-9x9-v0') - elif task == '15x15': - env = gym.make('memory_maze:MemoryMaze-15x15-v0') + + if task == "9x9": + env = gym.make("memory_maze:MemoryMaze-9x9-v0") + elif task == "15x15": + env = gym.make("memory_maze:MemoryMaze-15x15-v0") else: raise NotImplementedError(suite) from envs.memorymaze import MemoryMaze + env = MemoryMaze(env) env = wrappers.OneHotAction(env) elif suite == "crafter": diff --git a/envs/memorymaze.py b/envs/memorymaze.py index a194368..09603a3 100644 --- a/envs/memorymaze.py +++ b/envs/memorymaze.py @@ -8,80 +8,79 @@ import numpy as np ###from tf dreamerv2 code + class MemoryMaze: + def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)): + self._env = env + self._obs_is_dict = hasattr(self._env.observation_space, "spaces") + self._act_is_dict = hasattr(self._env.action_space, "spaces") + self._obs_key = obs_key + self._act_key = act_key + self._size = size + self._gray = False - def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)): - self._env = env - self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') - self._act_is_dict = hasattr(self._env.action_space, 'spaces') - self._obs_key = obs_key - self._act_key = act_key - self._size = size - self._gray = False + def __getattr__(self, name): + if name.startswith("__"): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - try: - return getattr(self._env, name) - except AttributeError: - raise ValueError(name) + @property + def obs_space(self): + if self._obs_is_dict: + spaces = self._env.observation_space.spaces.copy() + else: + spaces = {self._obs_key: self._env.observation_space} + return { + **spaces, + "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + "is_first": gym.spaces.Box(0, 1, (), dtype=np.bool), + "is_last": gym.spaces.Box(0, 1, (), dtype=np.bool), + "is_terminal": gym.spaces.Box(0, 1, (), dtype=np.bool), + } - @property - def obs_space(self): - if self._obs_is_dict: - spaces = self._env.observation_space.spaces.copy() - else: - spaces = {self._obs_key: self._env.observation_space} - return { - **spaces, - 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), - 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), - } + @property + def act_space(self): + if self._act_is_dict: + return self._env.action_space.spaces.copy() + else: + return {self._act_key: self._env.action_space} - @property - def act_space(self): - if self._act_is_dict: - return self._env.action_space.spaces.copy() - else: - return {self._act_key: self._env.action_space} + @property + def observation_space(self): + img_shape = self._size + ((1,) if self._gray else (3,)) + return gym.spaces.Dict( + { + "image": gym.spaces.Box(0, 255, img_shape, np.uint8), + } + ) - @property - def observation_space(self): - img_shape = self._size + ((1,) if self._gray else (3,)) - return gym.spaces.Dict( - { - "image": gym.spaces.Box(0, 255, img_shape, np.uint8), - } - ) + @property + def action_space(self): + space = self._env.action_space + space.discrete = True + return space - @property - def action_space(self): - space = self._env.action_space - space.discrete = True - return space - - def step(self, action): - # if not self._act_is_dict: - # action = action[self._act_key] - obs, reward, done, info = self._env.step(action) - if not self._obs_is_dict: - obs = {self._obs_key: obs} - # obs['reward'] = float(reward) - obs['is_first'] = False - obs['is_last'] = done - obs['is_terminal'] = info.get('is_terminal', False) - return obs, reward, done, info - - def reset(self): - obs = self._env.reset() - if not self._obs_is_dict: - obs = {self._obs_key: obs} - obs['reward'] = 0.0 - obs['is_first'] = True - obs['is_last'] = False - obs['is_terminal'] = False - return obs + def step(self, action): + # if not self._act_is_dict: + # action = action[self._act_key] + obs, reward, done, info = self._env.step(action) + if not self._obs_is_dict: + obs = {self._obs_key: obs} + # obs['reward'] = float(reward) + obs["is_first"] = False + obs["is_last"] = done + obs["is_terminal"] = info.get("is_terminal", False) + return obs, reward, done, info + def reset(self): + obs = self._env.reset() + if not self._obs_is_dict: + obs = {self._obs_key: obs} + obs["reward"] = 0.0 + obs["is_first"] = True + obs["is_last"] = False + obs["is_terminal"] = False + return obs