diff --git a/YOPO/policy/yopo_trainer.py b/YOPO/policy/yopo_trainer.py index fbed679..9e46716 100644 --- a/YOPO/policy/yopo_trainer.py +++ b/YOPO/policy/yopo_trainer.py @@ -72,7 +72,7 @@ class YopoTrainer: self.eval_one_epoch(self.epoch_i) if save_interval is not None and (self.epoch_i + 1) % save_interval == 0: self.progress_log.console.log("Saving model...") - policy_path = self.tensorboard_path + "/epoch{}.pth".format(epoch + 1, 0) + policy_path = self.tensorboard_path + "/epoch{}.pth".format(self.epoch_i + 1, 0) torch.save(self.policy.state_dict(), policy_path) self.progress_log.console.log("Train YOPO Finish!") self.progress_log.remove_task(total_progress)