print params first
This commit is contained in:
parent
bd94cfbd51
commit
b8725720da
@ -39,9 +39,6 @@ class YopoTrainer:
|
||||
# params
|
||||
self.traj_num = cfg['traj_num']
|
||||
|
||||
# loss
|
||||
self.yopo_loss = YOPOLoss()
|
||||
|
||||
# network
|
||||
print("Loading network...")
|
||||
self.policy = YopoNetwork()
|
||||
@ -53,6 +50,9 @@ class YopoTrainer:
|
||||
except FileNotFoundError:
|
||||
print("Training from scratch")
|
||||
|
||||
# loss
|
||||
self.yopo_loss = YOPOLoss()
|
||||
|
||||
# optimizer
|
||||
self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True)
|
||||
print("Network Loaded! Loading Dataset...")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user