modified memory maze and dependencies
This commit is contained in:
parent
bc7bd6f704
commit
edc26e42ed
12
dreamer.py
12
dreamer.py
@ -212,21 +212,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
|||||||
)
|
)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "MemoryMaze":
|
elif suite == "MemoryMaze":
|
||||||
import gym
|
|
||||||
|
|
||||||
if task == "9x9":
|
|
||||||
env = gym.make("memory_maze:MemoryMaze-9x9-v0")
|
|
||||||
elif task == "15x15":
|
|
||||||
env = gym.make("memory_maze:MemoryMaze-15x15-v0")
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(suite)
|
|
||||||
from envs.memorymaze import MemoryMaze
|
from envs.memorymaze import MemoryMaze
|
||||||
|
env = MemoryMaze(task)
|
||||||
env = MemoryMaze(env)
|
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "crafter":
|
elif suite == "crafter":
|
||||||
import envs.crafter as crafter
|
import envs.crafter as crafter
|
||||||
|
|
||||||
env = crafter.Crafter(task, config.size)
|
env = crafter.Crafter(task, config.size)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
else:
|
else:
|
||||||
|
@ -5,10 +5,14 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class MemoryMaze:
|
class MemoryMaze:
|
||||||
def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)):
|
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
|
||||||
self._env = env
|
if task == "9x9":
|
||||||
|
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0")
|
||||||
|
elif task == "15x15":
|
||||||
|
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(task)
|
||||||
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
||||||
self._act_is_dict = hasattr(self._env.action_space, "spaces")
|
|
||||||
self._obs_key = obs_key
|
self._obs_key = obs_key
|
||||||
self._act_key = act_key
|
self._act_key = act_key
|
||||||
self._size = size
|
self._size = size
|
||||||
@ -23,35 +27,18 @@ class MemoryMaze:
|
|||||||
raise ValueError(name)
|
raise ValueError(name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_space(self):
|
def observation_space(self):
|
||||||
if self._obs_is_dict:
|
if self._obs_is_dict:
|
||||||
spaces = self._env.observation_space.spaces.copy()
|
spaces = self._env.observation_space.spaces.copy()
|
||||||
else:
|
else:
|
||||||
spaces = {self._obs_key: self._env.observation_space}
|
spaces = {self._obs_key: self._env.observation_space}
|
||||||
return {
|
return gym.spaces.Dict({
|
||||||
**spaces,
|
**spaces,
|
||||||
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
|
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
|
||||||
"is_first": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_first": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
"is_last": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_last": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
}
|
})
|
||||||
|
|
||||||
@property
|
|
||||||
def act_space(self):
|
|
||||||
if self._act_is_dict:
|
|
||||||
return self._env.action_space.spaces.copy()
|
|
||||||
else:
|
|
||||||
return {self._act_key: self._env.action_space}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def observation_space(self):
|
|
||||||
img_shape = self._size + ((1,) if self._gray else (3,))
|
|
||||||
return gym.spaces.Dict(
|
|
||||||
{
|
|
||||||
"image": gym.spaces.Box(0, 255, img_shape, np.uint8),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
space = self._env.action_space
|
space = self._env.action_space
|
||||||
@ -59,12 +46,10 @@ class MemoryMaze:
|
|||||||
return space
|
return space
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
# if not self._act_is_dict:
|
|
||||||
# action = action[self._act_key]
|
|
||||||
obs, reward, done, info = self._env.step(action)
|
obs, reward, done, info = self._env.step(action)
|
||||||
if not self._obs_is_dict:
|
if not self._obs_is_dict:
|
||||||
obs = {self._obs_key: obs}
|
obs = {self._obs_key: obs}
|
||||||
# obs['reward'] = float(reward)
|
obs['reward'] = reward
|
||||||
obs["is_first"] = False
|
obs["is_first"] = False
|
||||||
obs["is_last"] = done
|
obs["is_last"] = done
|
||||||
obs["is_terminal"] = info.get("is_terminal", False)
|
obs["is_terminal"] = info.get("is_terminal", False)
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
|
setuptools==60.0.0
|
||||||
torch==2.0.0
|
torch==2.0.0
|
||||||
torchvision==0.15.1
|
torchvision==0.15.1
|
||||||
numpy==1.20.1
|
numpy==1.20.1
|
||||||
tensorboard==2.5.0
|
tensorboard==2.5.0
|
||||||
pandas==1.2.4
|
pandas==1.2.4
|
||||||
matplotlib==3.4.1
|
matplotlib==3.5.0
|
||||||
ruamel.yaml==0.17.4
|
ruamel.yaml==0.17.4
|
||||||
moviepy==1.0.3
|
moviepy==1.0.3
|
||||||
einops==0.3.0
|
einops==0.3.0
|
||||||
protobuf==3.20.0
|
protobuf==3.20.0
|
||||||
gym==0.19.0
|
gym==0.19.0
|
||||||
dm_control==1.0.9
|
dm_control==1.0.9
|
||||||
scipy==scipy==1.9.0
|
scipy==1.7.0
|
||||||
|
memory_maze
|
Loading…
x
Reference in New Issue
Block a user