53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
|
import torch
|
||
|
# Loss-weight curriculum as (first_epoch, weights) pairs, in ascending
# first_epoch order. Each weights tuple is laid out in the return order of
# read_wei_course: (anchors, arc, uni, pos, hol, rot, cur, safety, rsm).
# Stored as plain floats so every call builds fresh tensors; callers can
# mutate what they receive without corrupting the schedule.
_WEI_COURSE_STAGES = (
    (0, (1.0, 0.0, 0.0, 5.0, 0.0, 5.0, 0.0, 0.0, 0.0)),
    (1, (1.0, 0.05, 10.0, 5.0, 10.0, 5.0, 0.0, 0.0, 0.0)),
    (2, (1.0, 0.5, 100.0, 5.0, 100.0, 5.0, 5.0, 0.0, 0.5)),
    # NOTE(review): rsm drops to 0.25 here and returns to 0.5 at stage 4;
    # kept exactly as in the original schedule -- confirm this is intended.
    (3, (1.0, 0.5, 100.0, 5.0, 200.0, 5.0, 500.0, 0.0, 0.25)),
    (4, (1.0, 0.5, 100.0, 5.0, 500.0, 5.0, 500.0, 500.0, 0.5)),
)


def read_wei_course(epoch):
    """Return the curriculum loss weights in effect at ``epoch``.

    The active stage is the one with the largest starting epoch that is
    ``<= epoch``. Any epoch at or beyond the last stage's start (4) keeps
    using the final stage's weights.

    Args:
        epoch: Current epoch. Any value supporting ``>=`` against an int
            works (int, float, or a 0-d tensor).

    Returns:
        A 9-tuple of freshly created scalar tensors (default torch dtype):
        ``(wei_anchors, wei_arc, wei_uni, wei_pos, wei_hol, wei_rot,
        wei_cur, wei_safety, wei_rsm)``.

    Raises:
        ValueError: If ``epoch`` is below the first stage's start (e.g. a
            negative epoch) or is NaN. Previously this surfaced as an
            ``UnboundLocalError``.
    """
    selected = None
    # Every stage is tested, not just the first match. Because thresholds
    # ascend, the last stage that passes is the highest one reached. This
    # mirrors the original cascade of independent `if epoch >= k` checks.
    for first_epoch, stage_weights in _WEI_COURSE_STAGES:
        if epoch >= first_epoch:
            selected = stage_weights

    if selected is None:
        raise ValueError(
            f"epoch must be >= {_WEI_COURSE_STAGES[0][0]}, got {epoch!r}"
        )

    # Build new tensors on every call; torch.tensor(float) uses the default dtype.
    return tuple(torch.tensor(weight) for weight in selected)
|