dreamerv3-torch/envs/memorymaze.py

67 lines
2.0 KiB
Python
Raw Normal View History

2023-06-13 09:58:03 +08:00
import gym
import numpy as np
2023-06-17 23:29:53 +08:00
# Adapted from the TensorFlow DreamerV2 code.
2023-06-13 09:58:03 +08:00
2023-06-18 16:57:05 +09:00
class _ForwardedAttributeError(AttributeError, ValueError):
    """Raised when an attribute exists on neither the wrapper nor the wrapped env.

    Subclasses AttributeError so ``hasattr`` and ``getattr(obj, name, default)``
    behave normally, and ValueError so callers written against the previous
    behavior (which raised ValueError) keep working.
    """


class MemoryMaze:
    """Gym wrapper around the ``memory_maze`` environments.

    Observations are returned as dicts (a non-dict observation is stored under
    ``obs_key``) and augmented with boolean ``is_first``, ``is_last`` and
    ``is_terminal`` flags. Attributes not defined here are forwarded to the
    wrapped env.

    Args:
        task: Maze size; one of the keys of ``_TASK_IDS`` (e.g. ``"9x9"``).
        obs_key: Key used for the observation when the wrapped env's
            observation space is not a dict.
        act_key: Stored on the instance; not used by this class.
        size: Stored on the instance; not used by this class.

    Raises:
        NotImplementedError: If ``task`` is not a supported maze size.
    """

    # Gym IDs of the memory_maze environments, keyed by maze size.
    _TASK_IDS = {
        "9x9": "memory_maze:MemoryMaze-9x9-v0",
        "11x11": "memory_maze:MemoryMaze-11x11-v0",
        "13x13": "memory_maze:MemoryMaze-13x13-v0",
        "15x15": "memory_maze:MemoryMaze-15x15-v0",
    }

    def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
        try:
            env_id = self._TASK_IDS[task]
        except KeyError:
            raise NotImplementedError(task) from None
        self._env = gym.make(env_id)
        # Dict-like gym spaces expose their sub-spaces via `.spaces`.
        self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
        self._obs_key = obs_key
        self._act_key = act_key
        self._size = size
        self._gray = False

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped env.

        Only called when normal lookup fails. Dunder names are never
        forwarded. ``_env`` itself is refused, because otherwise an instance
        that lacks it (e.g. one created via ``__new__``) would recurse forever
        through ``self._env``.

        Raises:
            AttributeError: If ``name`` is a dunder or ``_env``.
            _ForwardedAttributeError: If the wrapped env lacks ``name`` too.
        """
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        try:
            return getattr(self._env, name)
        except AttributeError:
            raise _ForwardedAttributeError(name) from None

    @property
    def observation_space(self):
        """Dict space: the wrapped env's space(s) plus the three boolean flags."""
        if self._obs_is_dict:
            spaces = self._env.observation_space.spaces.copy()
        else:
            spaces = {self._obs_key: self._env.observation_space}
        return gym.spaces.Dict(
            {
                **spaces,
                "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
                "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
                "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
            }
        )

    @property
    def action_space(self):
        """The wrapped env's action space, tagged with ``discrete = True``.

        Note: the flag is set on the wrapped env's space object in place.
        """
        space = self._env.action_space
        space.discrete = True
        return space

    def _as_dict(self, obs):
        """Return a fresh dict observation, never the wrapped env's own object."""
        if self._obs_is_dict:
            return dict(obs)
        return {self._obs_key: obs}

    def step(self, action):
        """Step the wrapped env (4-tuple gym API) and add the episode flags.

        ``is_last`` mirrors ``done``. ``is_terminal`` is taken from
        ``info["is_terminal"]`` and defaults to False when absent.
        """
        obs, reward, done, info = self._env.step(action)
        obs = self._as_dict(obs)
        obs["is_first"] = False
        obs["is_last"] = done
        obs["is_terminal"] = info.get("is_terminal", False)
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped env and return the first observation with flags."""
        obs = self._as_dict(self._env.reset())
        obs["is_first"] = True
        obs["is_last"] = False
        obs["is_terminal"] = False
        return obs