applied formatter
This commit is contained in:
parent
775eb94e7f
commit
e3329b35e5
10
dreamer.py
10
dreamer.py
@ -213,13 +213,15 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
|||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "MemoryMaze":
|
elif suite == "MemoryMaze":
|
||||||
import gym
|
import gym
|
||||||
if task == '9x9':
|
|
||||||
env = gym.make('memory_maze:MemoryMaze-9x9-v0')
|
if task == "9x9":
|
||||||
elif task == '15x15':
|
env = gym.make("memory_maze:MemoryMaze-9x9-v0")
|
||||||
env = gym.make('memory_maze:MemoryMaze-15x15-v0')
|
elif task == "15x15":
|
||||||
|
env = gym.make("memory_maze:MemoryMaze-15x15-v0")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(suite)
|
raise NotImplementedError(suite)
|
||||||
from envs.memorymaze import MemoryMaze
|
from envs.memorymaze import MemoryMaze
|
||||||
|
|
||||||
env = MemoryMaze(env)
|
env = MemoryMaze(env)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "crafter":
|
elif suite == "crafter":
|
||||||
|
@ -8,19 +8,19 @@ import numpy as np
|
|||||||
|
|
||||||
###from tf dreamerv2 code
|
###from tf dreamerv2 code
|
||||||
|
|
||||||
class MemoryMaze:
|
|
||||||
|
|
||||||
def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
|
class MemoryMaze:
|
||||||
|
def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)):
|
||||||
self._env = env
|
self._env = env
|
||||||
self._obs_is_dict = hasattr(self._env.observation_space, 'spaces')
|
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
||||||
self._act_is_dict = hasattr(self._env.action_space, 'spaces')
|
self._act_is_dict = hasattr(self._env.action_space, "spaces")
|
||||||
self._obs_key = obs_key
|
self._obs_key = obs_key
|
||||||
self._act_key = act_key
|
self._act_key = act_key
|
||||||
self._size = size
|
self._size = size
|
||||||
self._gray = False
|
self._gray = False
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
if name.startswith('__'):
|
if name.startswith("__"):
|
||||||
raise AttributeError(name)
|
raise AttributeError(name)
|
||||||
try:
|
try:
|
||||||
return getattr(self._env, name)
|
return getattr(self._env, name)
|
||||||
@ -35,10 +35,10 @@ class MemoryMaze:
|
|||||||
spaces = {self._obs_key: self._env.observation_space}
|
spaces = {self._obs_key: self._env.observation_space}
|
||||||
return {
|
return {
|
||||||
**spaces,
|
**spaces,
|
||||||
'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
|
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
|
||||||
'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
|
"is_first": gym.spaces.Box(0, 1, (), dtype=np.bool),
|
||||||
'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
|
"is_last": gym.spaces.Box(0, 1, (), dtype=np.bool),
|
||||||
'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
|
"is_terminal": gym.spaces.Box(0, 1, (), dtype=np.bool),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -70,18 +70,17 @@ class MemoryMaze:
|
|||||||
if not self._obs_is_dict:
|
if not self._obs_is_dict:
|
||||||
obs = {self._obs_key: obs}
|
obs = {self._obs_key: obs}
|
||||||
# obs['reward'] = float(reward)
|
# obs['reward'] = float(reward)
|
||||||
obs['is_first'] = False
|
obs["is_first"] = False
|
||||||
obs['is_last'] = done
|
obs["is_last"] = done
|
||||||
obs['is_terminal'] = info.get('is_terminal', False)
|
obs["is_terminal"] = info.get("is_terminal", False)
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
obs = self._env.reset()
|
obs = self._env.reset()
|
||||||
if not self._obs_is_dict:
|
if not self._obs_is_dict:
|
||||||
obs = {self._obs_key: obs}
|
obs = {self._obs_key: obs}
|
||||||
obs['reward'] = 0.0
|
obs["reward"] = 0.0
|
||||||
obs['is_first'] = True
|
obs["is_first"] = True
|
||||||
obs['is_last'] = False
|
obs["is_last"] = False
|
||||||
obs['is_terminal'] = False
|
obs["is_terminal"] = False
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user