applied formatter

NM512 2023-06-18 16:57:05 +09:00
parent 775eb94e7f
commit e3329b35e5
2 changed files with 74 additions and 73 deletions
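The changes below (single-quoted strings rewritten to double quotes, blank lines inserted after in-function imports) match the style the Black formatter produces; the commit message does not name the tool, so treating it as Black is an assumption. A minimal sketch of checking a working copy against that style without modifying it:

import subprocess

# Run Black in check-only mode over the repository root (the "." path is an
# assumption). --check reports instead of rewriting; --diff prints what would
# change. Exit code 0 means nothing would be reformatted, 1 means at least
# one file would be.
result = subprocess.run(
    ["black", "--check", "--diff", "."],
    capture_output=True,
    text=True,
)
print(result.stdout)
print("clean" if result.returncode == 0 else "would be reformatted")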


@@ -213,13 +213,15 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         env = wrappers.OneHotAction(env)
     elif suite == "MemoryMaze":
         import gym
-        if task == '9x9':
-            env = gym.make('memory_maze:MemoryMaze-9x9-v0')
-        elif task == '15x15':
-            env = gym.make('memory_maze:MemoryMaze-15x15-v0')
+
+        if task == "9x9":
+            env = gym.make("memory_maze:MemoryMaze-9x9-v0")
+        elif task == "15x15":
+            env = gym.make("memory_maze:MemoryMaze-15x15-v0")
         else:
             raise NotImplementedError(suite)
         from envs.memorymaze import MemoryMaze
+
         env = MemoryMaze(env)
         env = wrappers.OneHotAction(env)
     elif suite == "crafter":


@@ -8,80 +8,79 @@ import numpy as np
 ###from tf dreamerv2 code
 class MemoryMaze:
-  def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
-    self._env = env
-    self._obs_is_dict = hasattr(self._env.observation_space, 'spaces')
-    self._act_is_dict = hasattr(self._env.action_space, 'spaces')
-    self._obs_key = obs_key
-    self._act_key = act_key
-    self._size = size
-    self._gray = False
-
-  def __getattr__(self, name):
-    if name.startswith('__'):
-      raise AttributeError(name)
-    try:
-      return getattr(self._env, name)
-    except AttributeError:
-      raise ValueError(name)
-
-  @property
-  def obs_space(self):
-    if self._obs_is_dict:
-      spaces = self._env.observation_space.spaces.copy()
-    else:
-      spaces = {self._obs_key: self._env.observation_space}
-    return {
-        **spaces,
-        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
-        'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
-    }
-
-  @property
-  def act_space(self):
-    if self._act_is_dict:
-      return self._env.action_space.spaces.copy()
-    else:
-      return {self._act_key: self._env.action_space}
-
-  @property
-  def observation_space(self):
-    img_shape = self._size + ((1,) if self._gray else (3,))
-    return gym.spaces.Dict(
-        {
-            "image": gym.spaces.Box(0, 255, img_shape, np.uint8),
-        }
-    )
-
-  @property
-  def action_space(self):
-    space = self._env.action_space
-    space.discrete = True
-    return space
-
-  def step(self, action):
-    # if not self._act_is_dict:
-    #   action = action[self._act_key]
-    obs, reward, done, info = self._env.step(action)
-    if not self._obs_is_dict:
-      obs = {self._obs_key: obs}
-    # obs['reward'] = float(reward)
-    obs['is_first'] = False
-    obs['is_last'] = done
-    obs['is_terminal'] = info.get('is_terminal', False)
-    return obs, reward, done, info
-
-  def reset(self):
-    obs = self._env.reset()
-    if not self._obs_is_dict:
-      obs = {self._obs_key: obs}
-    obs['reward'] = 0.0
-    obs['is_first'] = True
-    obs['is_last'] = False
-    obs['is_terminal'] = False
-    return obs
+    def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)):
+        self._env = env
+        self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
+        self._act_is_dict = hasattr(self._env.action_space, "spaces")
+        self._obs_key = obs_key
+        self._act_key = act_key
+        self._size = size
+        self._gray = False
+
+    def __getattr__(self, name):
+        if name.startswith("__"):
+            raise AttributeError(name)
+        try:
+            return getattr(self._env, name)
+        except AttributeError:
+            raise ValueError(name)
+
+    @property
+    def obs_space(self):
+        if self._obs_is_dict:
+            spaces = self._env.observation_space.spaces.copy()
+        else:
+            spaces = {self._obs_key: self._env.observation_space}
+        return {
+            **spaces,
+            "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+            "is_first": gym.spaces.Box(0, 1, (), dtype=np.bool),
+            "is_last": gym.spaces.Box(0, 1, (), dtype=np.bool),
+            "is_terminal": gym.spaces.Box(0, 1, (), dtype=np.bool),
+        }
+
+    @property
+    def act_space(self):
+        if self._act_is_dict:
+            return self._env.action_space.spaces.copy()
+        else:
+            return {self._act_key: self._env.action_space}
+
+    @property
+    def observation_space(self):
+        img_shape = self._size + ((1,) if self._gray else (3,))
+        return gym.spaces.Dict(
+            {
+                "image": gym.spaces.Box(0, 255, img_shape, np.uint8),
+            }
+        )
+
+    @property
+    def action_space(self):
+        space = self._env.action_space
+        space.discrete = True
+        return space
+
+    def step(self, action):
+        # if not self._act_is_dict:
+        #   action = action[self._act_key]
+        obs, reward, done, info = self._env.step(action)
+        if not self._obs_is_dict:
+            obs = {self._obs_key: obs}
+        # obs['reward'] = float(reward)
+        obs["is_first"] = False
+        obs["is_last"] = done
+        obs["is_terminal"] = info.get("is_terminal", False)
+        return obs, reward, done, info
+
+    def reset(self):
+        obs = self._env.reset()
+        if not self._obs_is_dict:
+            obs = {self._obs_key: obs}
+        obs["reward"] = 0.0
+        obs["is_first"] = True
+        obs["is_last"] = False
+        obs["is_terminal"] = False
+        return obs
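As a usage note for the reformatted wrapper, a minimal sketch of driving it by hand, mirroring the make_env branch above; it assumes the memory_maze package is installed and the old-style gym API (reset() returning only the observation, step() returning a 4-tuple), which is what the wrapper itself expects:

import gym
from envs.memorymaze import MemoryMaze

# Build the 9x9 task exactly as in make_env above (the repo's OneHotAction
# wrapper is omitted; it is not needed for this illustration).
env = MemoryMaze(gym.make("memory_maze:MemoryMaze-9x9-v0"))

obs = env.reset()
# reset() wraps the raw observation in a dict and adds the bookkeeping flags.
assert obs["is_first"] and not obs["is_last"] and obs["reward"] == 0.0

obs, reward, done, info = env.step(env.action_space.sample())
# step() forwards to the underlying env and tags the observation dict.
assert obs["is_last"] == done
print(obs["image"].shape, reward, done)

One caveat: obs_space builds its flag spaces with dtype=np.bool, an alias that NumPy 1.24 removed, so code touching obs_space on a newer NumPy needs np.bool_ (or plain bool) there instead.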