import torch

# Weight schedule consumed by read_wei_course.
# Each row is (first_epoch, weights). The weights are listed in the same order
# the function returns them:
#   anchors, arc, uni, pos, hol, rot, cur, safety, rsm
# A row applies from its first_epoch up to (but excluding) the next row's
# first_epoch. The last row applies to every later epoch.
# Floats are stored here, not tensors, so each call hands out fresh tensors.
# That way a caller that modifies a returned tensor in place cannot corrupt
# the schedule.
# NOTE(review): rsm goes 0.5 -> 0.25 -> 0.5 over epochs 2-4. These values are
# copied unchanged from the original code; confirm the dip is intentional.
_WEI_COURSE = (
    (0, (1.0, 0.0, 0.0, 5.0, 0.0, 5.0, 0.0, 0.0, 0.0)),
    (1, (1.0, 0.05, 10.0, 5.0, 10.0, 5.0, 0.0, 0.0, 0.0)),
    (2, (1.0, 0.5, 100.0, 5.0, 100.0, 5.0, 5.0, 0.0, 0.5)),
    (3, (1.0, 0.5, 100.0, 5.0, 200.0, 5.0, 500.0, 0.0, 0.25)),
    (4, (1.0, 0.5, 100.0, 5.0, 500.0, 5.0, 500.0, 500.0, 0.5)),
)


def read_wei_course(epoch):
    """Return the scheduled weights for ``epoch``.

    The schedule is stepwise: the stage with the largest start epoch that is
    ``<= epoch`` is used. Epochs 4 and above all use the final stage.
    Presumably these are loss-term weights for a curriculum; that is inferred
    from the names only, so confirm with the caller.

    Args:
        epoch: Current epoch, a non-negative number. Non-integer values are
            compared against the integer stage boundaries directly.

    Returns:
        A tuple of nine new scalar tensors (default torch dtype), in this order:
        ``(wei_anchors, wei_arc, wei_uni, wei_pos, wei_hol, wei_rot, wei_cur,
        wei_safety, wei_rsm)``.

    Raises:
        ValueError: If ``epoch`` is negative or NaN. The original code raised
            an opaque ``UnboundLocalError`` for these inputs.
    """
    # `not epoch >= 0` (rather than `epoch < 0`) also rejects NaN, for which
    # every comparison is False.
    if not epoch >= 0:
        raise ValueError(f"epoch must be >= 0, got {epoch!r}")

    # Keep the last stage whose start epoch has been reached. This matches
    # the original cascade of `if epoch >= k` blocks, where later blocks
    # overwrite earlier ones.
    weights = _WEI_COURSE[0][1]
    for first_epoch, stage_weights in _WEI_COURSE:
        if epoch >= first_epoch:
            weights = stage_weights

    return tuple(torch.tensor(w) for w in weights)