diff --git a/configs.yaml b/configs.yaml index b5c1e31..e6dc616 100644 --- a/configs.yaml +++ b/configs.yaml @@ -177,6 +177,9 @@ minecraft: task: minecraft_diamond step: 1e8 envs: 16 + # no eval + eval_episode_num: 0 + eval_every: 1e8 action_repeat: 1 train_ratio: 16 video_pred_log: true diff --git a/dreamer.py b/dreamer.py index 252194d..aede28c 100644 --- a/dreamer.py +++ b/dreamer.py @@ -230,7 +230,8 @@ def make_env(config, mode): env = wrappers.TimeLimit(env, config.time_limit) env = wrappers.SelectAction(env, key="action") env = wrappers.UUID(env) - env = wrappers.RewardObs(env) + if suite == "minecraft": + env = wrappers.RewardObs(env) return env @@ -326,20 +327,21 @@ def main(config): # make sure eval will be executed once after config.steps while agent._step < config.steps + config.eval_every: logger.write() - print("Start evaluation.") - eval_policy = functools.partial(agent, training=False) - tools.simulate( - eval_policy, - eval_envs, - eval_eps, - config.evaldir, - logger, - is_eval=True, - episodes=config.eval_episode_num, - ) - if config.video_pred_log: - video_pred = agent._wm.video_pred(next(eval_dataset)) - logger.video("eval_openl", to_np(video_pred)) + if config.eval_episode_num > 0: + print("Start evaluation.") + eval_policy = functools.partial(agent, training=False) + tools.simulate( + eval_policy, + eval_envs, + eval_eps, + config.evaldir, + logger, + is_eval=True, + episodes=config.eval_episode_num, + ) + if config.video_pred_log: + video_pred = agent._wm.video_pred(next(eval_dataset)) + logger.video("eval_openl", to_np(video_pred)) print("Start training.") state = tools.simulate( agent, diff --git a/networks.py b/networks.py index 5d05a9b..b1edfcc 100644 --- a/networks.py +++ b/networks.py @@ -415,7 +415,7 @@ class MultiDecoder(nn.Module): vector_dist, ): super(MultiDecoder, self).__init__() - excluded = ("is_first", "is_last", "is_terminal", "reward") + excluded = ("is_first", "is_last", "is_terminal") shapes = {k: v for k, v in shapes.items() if k not in excluded} self.cnn_shapes = { k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)