modified the memorymaze environment
This commit is contained in:
parent
606ec8af8c
commit
7607a92d71
18
configs.yaml
18
configs.yaml
@ -200,20 +200,16 @@ minecraft:
|
|||||||
break_speed: 100.0
|
break_speed: 100.0
|
||||||
time_limit: 36000
|
time_limit: 36000
|
||||||
|
|
||||||
|
memorymaze:
|
||||||
|
steps: 1e8
|
||||||
|
action_repeat: 2
|
||||||
|
actor_dist: 'onehot'
|
||||||
|
imag_gradient: 'reinforce'
|
||||||
|
task: 'memorymaze_9x9'
|
||||||
|
|
||||||
debug:
|
debug:
|
||||||
debug: True
|
debug: True
|
||||||
pretrain: 1
|
pretrain: 1
|
||||||
prefill: 1
|
prefill: 1
|
||||||
batch_size: 10
|
batch_size: 10
|
||||||
batch_length: 20
|
batch_length: 20
|
||||||
|
|
||||||
MemoryMaze:
|
|
||||||
steps: 1e8
|
|
||||||
action_repeat: 2
|
|
||||||
actor_dist: 'onehot'
|
|
||||||
imag_gradient: 'reinforce'
|
|
||||||
task: 'MemoryMaze_9x9'
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -216,7 +216,7 @@ def make_env(config, mode):
|
|||||||
seed=config.seed,
|
seed=config.seed,
|
||||||
)
|
)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "MemoryMaze":
|
elif suite == "memorymaze":
|
||||||
from envs.memorymaze import MemoryMaze
|
from envs.memorymaze import MemoryMaze
|
||||||
|
|
||||||
env = MemoryMaze(task, seed=config.seed)
|
env = MemoryMaze(task, seed=config.seed)
|
||||||
|
@ -6,12 +6,8 @@ import numpy as np
|
|||||||
|
|
||||||
class MemoryMaze:
|
class MemoryMaze:
|
||||||
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
|
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
|
||||||
if task == "9x9":
|
# 9x9, 11x11, 13x13 and 15x15 are available
|
||||||
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
|
self._env = gym.make(f"memory_maze:MemoryMaze-{task}-v0", seed=seed)
|
||||||
elif task == "15x15":
|
|
||||||
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(task)
|
|
||||||
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
||||||
self._obs_key = obs_key
|
self._obs_key = obs_key
|
||||||
self._act_key = act_key
|
self._act_key = act_key
|
||||||
|
Loading…
x
Reference in New Issue
Block a user