diff --git a/YOPO/policy/yopo_trainer.py b/YOPO/policy/yopo_trainer.py index 96d837e..626cbcc 100644 --- a/YOPO/policy/yopo_trainer.py +++ b/YOPO/policy/yopo_trainer.py @@ -70,6 +70,7 @@ class YopoTrainer: for self.epoch_i in range(epoch): self.policy.train() self.train_one_epoch(self.epoch_i, total_progress) + self.policy.eval() self.eval_one_epoch(self.epoch_i) if save_interval is not None and (self.epoch_i + 1) % save_interval == 0: self.progress_log.console.log("Saving model...")