diff --git a/YOPO/policy/yopo_trainer.py b/YOPO/policy/yopo_trainer.py index f07352a..5c766a3 100644 --- a/YOPO/policy/yopo_trainer.py +++ b/YOPO/policy/yopo_trainer.py @@ -39,9 +39,6 @@ class YopoTrainer: # params self.traj_num = cfg['traj_num'] - # loss - self.yopo_loss = YOPOLoss() - # network print("Loading network...") self.policy = YopoNetwork() @@ -53,6 +50,9 @@ class YopoTrainer: except FileNotFoundError: print("Training from scratch") + # loss + self.yopo_loss = YOPOLoss() + # optimizer self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True) print("Network Loaded! Loading Dataset...")