modified the memorymaze environment

This commit is contained in:
NM512 2023-08-16 21:54:09 +09:00
parent 606ec8af8c
commit 7607a92d71
3 changed files with 10 additions and 18 deletions

View File

@ -200,20 +200,16 @@ minecraft:
break_speed: 100.0
time_limit: 36000
memorymaze:
steps: 1e8
action_repeat: 2
actor_dist: 'onehot'
imag_gradient: 'reinforce'
task: 'memorymaze_9x9'
debug:
debug: True
pretrain: 1
prefill: 1
batch_size: 10
batch_length: 20
MemoryMaze:
steps: 1e8
action_repeat: 2
actor_dist: 'onehot'
imag_gradient: 'reinforce'
task: 'MemoryMaze_9x9'

View File

@ -216,7 +216,7 @@ def make_env(config, mode):
seed=config.seed,
)
env = wrappers.OneHotAction(env)
elif suite == "MemoryMaze":
elif suite == "memorymaze":
from envs.memorymaze import MemoryMaze
env = MemoryMaze(task, seed=config.seed)

View File

@ -6,12 +6,8 @@ import numpy as np
class MemoryMaze:
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
if task == "9x9":
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
elif task == "15x15":
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
else:
raise NotImplementedError(task)
# 9x9, 11x11, 13x13 and 15x15 are available
self._env = gym.make(f"memory_maze:MemoryMaze-{task}-v0", seed=seed)
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
self._obs_key = obs_key
self._act_key = act_key