modified memory maze and dependencies

This commit is contained in:
NM512 2023-06-18 19:42:48 +09:00
parent bc7bd6f704
commit edc26e42ed
3 changed files with 16 additions and 39 deletions

View File

@ -212,21 +212,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
) )
env = wrappers.OneHotAction(env) env = wrappers.OneHotAction(env)
elif suite == "MemoryMaze": elif suite == "MemoryMaze":
import gym
if task == "9x9":
env = gym.make("memory_maze:MemoryMaze-9x9-v0")
elif task == "15x15":
env = gym.make("memory_maze:MemoryMaze-15x15-v0")
else:
raise NotImplementedError(suite)
from envs.memorymaze import MemoryMaze from envs.memorymaze import MemoryMaze
env = MemoryMaze(task)
env = MemoryMaze(env)
env = wrappers.OneHotAction(env) env = wrappers.OneHotAction(env)
elif suite == "crafter": elif suite == "crafter":
import envs.crafter as crafter import envs.crafter as crafter
env = crafter.Crafter(task, config.size) env = crafter.Crafter(task, config.size)
env = wrappers.OneHotAction(env) env = wrappers.OneHotAction(env)
else: else:

View File

@ -5,10 +5,14 @@ import numpy as np
class MemoryMaze: class MemoryMaze:
def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)): def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
self._env = env if task == "9x9":
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0")
elif task == "15x15":
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0")
else:
raise NotImplementedError(task)
self._obs_is_dict = hasattr(self._env.observation_space, "spaces") self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
self._act_is_dict = hasattr(self._env.action_space, "spaces")
self._obs_key = obs_key self._obs_key = obs_key
self._act_key = act_key self._act_key = act_key
self._size = size self._size = size
@ -23,35 +27,18 @@ class MemoryMaze:
raise ValueError(name) raise ValueError(name)
@property @property
def obs_space(self): def observation_space(self):
if self._obs_is_dict: if self._obs_is_dict:
spaces = self._env.observation_space.spaces.copy() spaces = self._env.observation_space.spaces.copy()
else: else:
spaces = {self._obs_key: self._env.observation_space} spaces = {self._obs_key: self._env.observation_space}
return { return gym.spaces.Dict({
**spaces, **spaces,
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
"is_first": gym.spaces.Box(0, 1, (), dtype=bool), "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
"is_last": gym.spaces.Box(0, 1, (), dtype=bool), "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool), "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
} })
@property
def act_space(self):
if self._act_is_dict:
return self._env.action_space.spaces.copy()
else:
return {self._act_key: self._env.action_space}
@property
def observation_space(self):
img_shape = self._size + ((1,) if self._gray else (3,))
return gym.spaces.Dict(
{
"image": gym.spaces.Box(0, 255, img_shape, np.uint8),
}
)
@property @property
def action_space(self): def action_space(self):
space = self._env.action_space space = self._env.action_space
@ -59,12 +46,10 @@ class MemoryMaze:
return space return space
def step(self, action): def step(self, action):
# if not self._act_is_dict:
# action = action[self._act_key]
obs, reward, done, info = self._env.step(action) obs, reward, done, info = self._env.step(action)
if not self._obs_is_dict: if not self._obs_is_dict:
obs = {self._obs_key: obs} obs = {self._obs_key: obs}
# obs['reward'] = float(reward) obs['reward'] = reward
obs["is_first"] = False obs["is_first"] = False
obs["is_last"] = done obs["is_last"] = done
obs["is_terminal"] = info.get("is_terminal", False) obs["is_terminal"] = info.get("is_terminal", False)

View File

@ -1,13 +1,15 @@
setuptools==60.0.0
torch==2.0.0 torch==2.0.0
torchvision==0.15.1 torchvision==0.15.1
numpy==1.20.1 numpy==1.20.1
tensorboard==2.5.0 tensorboard==2.5.0
pandas==1.2.4 pandas==1.2.4
matplotlib==3.4.1 matplotlib==3.5.0
ruamel.yaml==0.17.4 ruamel.yaml==0.17.4
moviepy==1.0.3 moviepy==1.0.3
einops==0.3.0 einops==0.3.0
protobuf==3.20.0 protobuf==3.20.0
gym==0.19.0 gym==0.19.0
dm_control==1.0.9 dm_control==1.0.9
scipy==scipy==1.9.0 scipy==1.7.0
memory_maze