diff --git a/dreamer.py b/dreamer.py
index 3224e51..eeadf5a 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -213,10 +213,12 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         env = wrappers.OneHotAction(env)
     elif suite == "MemoryMaze":
         from envs.memorymaze import MemoryMaze
 
+        env = MemoryMaze(task)
         env = wrappers.OneHotAction(env)
     elif suite == "crafter":
         import envs.crafter as crafter
 
+        env = crafter.Crafter(task, config.size)
         env = wrappers.OneHotAction(env)
     else:
@@ -254,17 +256,19 @@ class ProcessEpisodeWrap:
         length = len(episode["reward"]) - 1
         score = float(episode["reward"].astype(np.float64).sum())
         video = episode["image"]
+        # add new episode
         cache[str(filename)] = episode
         if mode == "train":
-            total = 0
+            step_in_dataset = 0
             for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
-                if not config.dataset_size or total <= config.dataset_size - length:
-                    total += len(ep["reward"]) - 1
+                if (
+                    not config.dataset_size
+                    or step_in_dataset + (len(ep["reward"]) - 1) <= config.dataset_size
+                ):
+                    step_in_dataset += len(ep["reward"]) - 1
                 else:
                     del cache[key]
-            logger.scalar("dataset_size", total)
-            # use dataset_size as log step for a condition of envs > 1
-            log_step = total * config.action_repeat
+            logger.scalar("dataset_size", step_in_dataset)
         elif mode == "eval":
             # keep only last item for saving memory
             while len(cache) > 1:
@@ -285,7 +289,6 @@ class ProcessEpisodeWrap:
             score = sum(cls.eval_scores) / len(cls.eval_scores)
             length = sum(cls.eval_lengths) / len(cls.eval_lengths)
             episode_num = len(cls.eval_scores)
-            log_step = logger.step
             logger.video(f"{mode}_policy", video[None])
             cls.eval_done = True
 
@@ -295,7 +298,7 @@ class ProcessEpisodeWrap:
         logger.scalar(
             f"{mode}_episodes", len(cache) if mode == "train" else episode_num
         )
-        logger.write(step=log_step)
+        logger.write(step=logger.step)
 
 
 def main(config):
diff --git a/envs/memorymaze.py b/envs/memorymaze.py
index d82971f..20717a8 100644
--- a/envs/memorymaze.py
+++ b/envs/memorymaze.py
@@ -32,13 +32,16 @@ class MemoryMaze:
             spaces = self._env.observation_space.spaces.copy()
         else:
             spaces = {self._obs_key: self._env.observation_space}
-        return gym.spaces.Dict({
-            **spaces,
-            "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
-            "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
-            "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
-            "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
-        })
+        return gym.spaces.Dict(
+            {
+                **spaces,
+                "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+                "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
+                "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
+                "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
+            }
+        )
+
     @property
     def action_space(self):
         space = self._env.action_space
@@ -49,7 +52,7 @@ class MemoryMaze:
         obs, reward, done, info = self._env.step(action)
         if not self._obs_is_dict:
             obs = {self._obs_key: obs}
-        obs['reward'] = reward
+        obs["reward"] = reward
        obs["is_first"] = False
         obs["is_last"] = done
         obs["is_terminal"] = info.get("is_terminal", False)
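For reviewers, a self-contained sketch of the new dataset-trimming behavior in `ProcessEpisodeWrap`: the cache is walked newest-episode-first, and an episode is kept only while the accumulated step count stays within `config.dataset_size`. The toy `cache` contents and the `dataset_size` value below are illustrative, not taken from the repository.

```python
import numpy as np

# Illustrative stand-ins for the episode cache and config.dataset_size
# (made-up values; each episode has 101 rewards, i.e. 100 environment steps).
cache = {f"ep{i:03d}": {"reward": np.zeros(101)} for i in range(5)}
dataset_size = 250

step_in_dataset = 0
# sorted() snapshots the items, so deleting from cache inside the loop is safe
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
    if not dataset_size or step_in_dataset + (len(ep["reward"]) - 1) <= dataset_size:
        # episode still fits in the step budget: keep it and count its steps
        step_in_dataset += len(ep["reward"]) - 1
    else:
        # budget exceeded: drop the older episode
        del cache[key]

print(sorted(cache), step_in_dataset)  # ['ep003', 'ep004'] 200
```

Compared with the removed condition (`total <= config.dataset_size - length`, where `length` is the newest episode's length), the new check adds the candidate episode's own step count before comparing, so the retained steps never exceed the configured budget.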