diff --git a/configs.yaml b/configs.yaml index 6074feb..7c246e1 100644 --- a/configs.yaml +++ b/configs.yaml @@ -25,7 +25,7 @@ defaults: action_repeat: 2 time_limit: 1000 grayscale: False - prefill: 2500 + prefill: 250 #0 eval_noise: 0.0 reward_EMA: True diff --git a/dreamer.py b/dreamer.py index 61ae538..04b46ab 100644 --- a/dreamer.py +++ b/dreamer.py @@ -215,7 +215,7 @@ def make_env(config, logger, mode, train_eps, eval_eps): env = gym.make('memory_maze:MemoryMaze-9x9-v0') from envs.memmazeEnv import MZGymWrapper env = MZGymWrapper(env) - + #from envs.memmazeEnv import OneHotAction as OneHotAction2 env = wrappers.OneHotAction(env) elif suite == "---------mazed": from memory_maze import tasks diff --git a/envs/memmazeEnv.py b/envs/memmazeEnv.py index baabf90..37a835f 100644 --- a/envs/memmazeEnv.py +++ b/envs/memmazeEnv.py @@ -47,17 +47,32 @@ class MZGymWrapper: else: return {self._act_key: self._env.action_space} + @property + def observation_space(self): + img_shape = self._size + ((1,) if self._gray else (3,)) + return gym.spaces.Dict( + { + "image": gym.spaces.Box(0, 255, img_shape, np.uint8), + } + ) + + @property + def action_space(self): + space = self._env.action_space + space.discrete = True + return space + def step(self, action): - if not self._act_is_dict: - action = action[self._act_key] + # if not self._act_is_dict: + # action = action[self._act_key] obs, reward, done, info = self._env.step(action) if not self._obs_is_dict: obs = {self._obs_key: obs} - obs['reward'] = float(reward) + # obs['reward'] = float(reward) obs['is_first'] = False obs['is_last'] = done obs['is_terminal'] = info.get('is_terminal', done) - return obs + return obs, reward, done, info def reset(self): obs = self._env.reset() diff --git a/envs/wrappers.py b/envs/wrappers.py index 1a4a58b..9769fc9 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -77,8 +77,8 @@ class CollectDataset: dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] elif np.issubdtype(value.dtype, np.uint8): dtype = np.uint8 - elif np.issubdtype(value.dtype, np.bool): - dtype = np.bool + elif np.issubdtype(value.dtype, np.bool_): + dtype = np.bool_ else: raise NotImplementedError(value.dtype) return value.astype(dtype) @@ -96,6 +96,7 @@ class TimeLimit: def step(self, action): assert self._step is not None, "Must reset environment." obs, reward, done, info = self._env.step(action) + # teets = self._env.step(action) self._step += 1 if self._step >= self._duration: done = True