diff --git a/configs.yaml b/configs.yaml index a9bd5bd..35e3fea 100644 --- a/configs.yaml +++ b/configs.yaml @@ -200,20 +200,16 @@ minecraft: break_speed: 100.0 time_limit: 36000 +memorymaze: + steps: 1e8 + action_repeat: 2 + actor_dist: 'onehot' + imag_gradient: 'reinforce' + task: 'memorymaze_9x9' + debug: debug: True pretrain: 1 prefill: 1 batch_size: 10 batch_length: 20 - -MemoryMaze: - steps: 1e8 - action_repeat: 2 - actor_dist: 'onehot' - imag_gradient: 'reinforce' - task: 'MemoryMaze_9x9' - - - - diff --git a/dreamer.py b/dreamer.py index 1b82f60..b3c5705 100644 --- a/dreamer.py +++ b/dreamer.py @@ -216,7 +216,7 @@ def make_env(config, mode): seed=config.seed, ) env = wrappers.OneHotAction(env) - elif suite == "MemoryMaze": + elif suite == "memorymaze": from envs.memorymaze import MemoryMaze env = MemoryMaze(task, seed=config.seed) diff --git a/envs/memorymaze.py b/envs/memorymaze.py index 19ca980..31f9178 100644 --- a/envs/memorymaze.py +++ b/envs/memorymaze.py @@ -6,12 +6,8 @@ import numpy as np class MemoryMaze: def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0): - if task == "9x9": - self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed) - elif task == "15x15": - self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed) - else: - raise NotImplementedError(task) + # 9x9, 11x11, 13x13 and 15x15 are available + self._env = gym.make(f"memory_maze:MemoryMaze-{task}-v0", seed=seed) self._obs_is_dict = hasattr(self._env.observation_space, "spaces") self._obs_key = obs_key self._act_key = act_key