diff --git a/dreamer.py b/dreamer.py index 5c8a17c..252194d 100644 --- a/dreamer.py +++ b/dreamer.py @@ -323,7 +323,8 @@ def main(config): agent.load_state_dict(torch.load(logdir / "latest_model.pt")) agent._should_pretrain._once = False - while agent._step < config.steps: + # Extend the loop bound by config.eval_every so that evaluation still runs after agent._step reaches config.steps. + while agent._step < config.steps + config.eval_every: logger.write() print("Start evaluation.") eval_policy = functools.partial(agent, training=False)