diff --git a/dreamer.py b/dreamer.py index 782d13a..3224e51 100644 --- a/dreamer.py +++ b/dreamer.py @@ -212,21 +212,11 @@ def make_env(config, logger, mode, train_eps, eval_eps): ) env = wrappers.OneHotAction(env) elif suite == "MemoryMaze": - import gym - - if task == "9x9": - env = gym.make("memory_maze:MemoryMaze-9x9-v0") - elif task == "15x15": - env = gym.make("memory_maze:MemoryMaze-15x15-v0") - else: - raise NotImplementedError(suite) from envs.memorymaze import MemoryMaze - - env = MemoryMaze(env) + env = MemoryMaze(task) env = wrappers.OneHotAction(env) elif suite == "crafter": import envs.crafter as crafter - env = crafter.Crafter(task, config.size) env = wrappers.OneHotAction(env) else: diff --git a/envs/memorymaze.py b/envs/memorymaze.py index 0eaefd0..d82971f 100644 --- a/envs/memorymaze.py +++ b/envs/memorymaze.py @@ -5,10 +5,14 @@ import numpy as np class MemoryMaze: - def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)): - self._env = env + def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)): + if task == "9x9": + self._env = gym.make("memory_maze:MemoryMaze-9x9-v0") + elif task == "15x15": + self._env = gym.make("memory_maze:MemoryMaze-15x15-v0") + else: + raise NotImplementedError(task) self._obs_is_dict = hasattr(self._env.observation_space, "spaces") - self._act_is_dict = hasattr(self._env.action_space, "spaces") self._obs_key = obs_key self._act_key = act_key self._size = size @@ -23,35 +27,18 @@ class MemoryMaze: raise ValueError(name) @property - def obs_space(self): + def observation_space(self): if self._obs_is_dict: spaces = self._env.observation_space.spaces.copy() else: spaces = {self._obs_key: self._env.observation_space} - return { + return gym.spaces.Dict({ **spaces, "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), "is_first": gym.spaces.Box(0, 1, (), dtype=bool), "is_last": gym.spaces.Box(0, 1, (), dtype=bool), "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool), - } - - @property - def act_space(self): - if self._act_is_dict: - return self._env.action_space.spaces.copy() - else: - return {self._act_key: self._env.action_space} - - @property - def observation_space(self): - img_shape = self._size + ((1,) if self._gray else (3,)) - return gym.spaces.Dict( - { - "image": gym.spaces.Box(0, 255, img_shape, np.uint8), - } - ) - + }) @property def action_space(self): space = self._env.action_space @@ -59,12 +46,10 @@ class MemoryMaze: return space def step(self, action): - # if not self._act_is_dict: - # action = action[self._act_key] obs, reward, done, info = self._env.step(action) if not self._obs_is_dict: obs = {self._obs_key: obs} - # obs['reward'] = float(reward) + obs['reward'] = reward obs["is_first"] = False obs["is_last"] = done obs["is_terminal"] = info.get("is_terminal", False) diff --git a/requirements.txt b/requirements.txt index 934d685..c47853e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ +setuptools==60.0.0 torch==2.0.0 torchvision==0.15.1 numpy==1.20.1 tensorboard==2.5.0 pandas==1.2.4 -matplotlib==3.4.1 +matplotlib==3.5.0 ruamel.yaml==0.17.4 moviepy==1.0.3 einops==0.3.0 protobuf==3.20.0 gym==0.19.0 dm_control==1.0.9 -scipy==scipy==1.9.0 \ No newline at end of file +scipy==1.7.0 +memory_maze \ No newline at end of file