print params first
This commit is contained in:
parent
bd94cfbd51
commit
b8725720da
@ -39,9 +39,6 @@ class YopoTrainer:
|
|||||||
# params
|
# params
|
||||||
self.traj_num = cfg['traj_num']
|
self.traj_num = cfg['traj_num']
|
||||||
|
|
||||||
# loss
|
|
||||||
self.yopo_loss = YOPOLoss()
|
|
||||||
|
|
||||||
# network
|
# network
|
||||||
print("Loading network...")
|
print("Loading network...")
|
||||||
self.policy = YopoNetwork()
|
self.policy = YopoNetwork()
|
||||||
@ -53,6 +50,9 @@ class YopoTrainer:
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("Training from scratch")
|
print("Training from scratch")
|
||||||
|
|
||||||
|
# loss
|
||||||
|
self.yopo_loss = YOPOLoss()
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True)
|
self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True)
|
||||||
print("Network Loaded! Loading Dataset...")
|
print("Network Loaded! Loading Dataset...")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user