Modify to global config and simplified speed adjustment during training and testing

This commit is contained in:
TJU_Lu 2025-07-06 21:31:41 +08:00
parent 788e9bc979
commit bd94cfbd51
14 changed files with 98 additions and 110 deletions

0
YOPO/config/__init__.py Normal file
View File

22
YOPO/config/config.py Normal file
View File

@ -0,0 +1,22 @@
import os
from ruamel.yaml import YAML
# Global Configuration Management
class Config:
def __init__(self):
base_dir = os.path.dirname(os.path.abspath(__file__))
self._data = YAML().load(open(os.path.join(base_dir, "traj_opt.yaml"), 'r'))
self._data["train"] = True
self._data["goal_length"] = 2.0 * self._data['radio_range']
self._data["sgm_time"] = 2 * self._data["radio_range"] / self._data["vel_max_train"]
self._data["traj_num"] = self._data['horizon_num'] * self._data['vertical_num'] * self._data["radio_num"]
def __getitem__(self, key):
return self._data[key]
def __setitem__(self, key, value):
self._data[key] = value
cfg = Config()

View File

@ -1,11 +1,12 @@
# IMPORTANT PARAM: actual velocity in training / testing # IMPORTANT: velocity in testing (modifiable)
velocity: 6.0 velocity: 6.0
# used to align the vel/acc, ensure consistency between testing and training.
# during testing, if vel*n then acc*n*n, please refer to ../policy/primitive.py
vel_align: 6.0
acc_align: 6.0
# IMPORTANT PARAM: weight of penalties (for unit speed) # IMPORTANT: vel_max and acc_max in training
# not the actual values in testing, ensure consistency between testing and training, please refer to ../policy/primitive.py
vel_max_train: 6.0
acc_max_train: 6.0
# IMPORTANT: weight of costs for unit speed (can be visualized in tensorboard)
wg: 0.1 # guidance wg: 0.1 # guidance
ws: 10.0 # smoothness ws: 10.0 # smoothness
wc: 0.1 # collision wc: 0.1 # collision
@ -15,7 +16,7 @@ dataset_path: "../dataset"
image_height: 96 image_height: 96
image_width: 160 image_width: 160
# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max) # trajectory and primitive param (Note: traj_time = 2 * radio / vel_max) (can be visualized in rviz)
horizon_num: 5 horizon_num: 5
vertical_num: 3 vertical_num: 3
horizon_camera_fov: 90.0 horizon_camera_fov: 90.0
@ -29,7 +30,7 @@ radio_num: 1 # only support 1 currently
d0: 1.2 d0: 1.2
r: 0.6 r: 0.6
# distribution parameters for unit state sampling # distribution for unit state sampling (can be visualized by 'python ./policy/yopo_dataset.py')
vx_mean_unit: 0.4 vx_mean_unit: 0.4
vy_mean_unit: 0.0 vy_mean_unit: 0.0
vz_mean_unit: 0.0 vz_mean_unit: 0.0

View File

@ -1,16 +1,13 @@
import os
import torch.nn as nn import torch.nn as nn
import torch as th import torch as th
import torch.nn.functional as F import torch.nn.functional as F
from ruamel.yaml import YAML from config.config import cfg
class GuidanceLoss(nn.Module): class GuidanceLoss(nn.Module):
def __init__(self): def __init__(self):
super(GuidanceLoss, self).__init__() super(GuidanceLoss, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__)) self.goal_length = cfg['goal_length']
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.goal_length = 2.0 * cfg['radio_range']
def forward(self, Df, Dp, goal): def forward(self, Df, Dp, goal):
""" """

View File

@ -1,8 +1,7 @@
import os
import math import math
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
from ruamel.yaml import YAML from config.config import cfg
from loss.safety_loss import SafetyLoss from loss.safety_loss import SafetyLoss
from loss.smoothness_loss import SmoothnessLoss from loss.smoothness_loss import SmoothnessLoss
from loss.guidance_loss import GuidanceLoss from loss.guidance_loss import GuidanceLoss
@ -17,20 +16,18 @@ class YOPOLoss(nn.Module):
df: fixed parameters df: fixed parameters
""" """
super(YOPOLoss, self).__init__() super(YOPOLoss, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__)) self.sgm_time = cfg["sgm_time"]
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.sgm_time = 2 * cfg["radio_range"] / cfg["velocity"]
self.device = th.device("cuda" if th.cuda.is_available() else "cpu") self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
self._C, self._B, self._L, self._R = self.qp_generation() self._C, self._B, self._L, self._R = self.qp_generation()
self._R = self._R.to(self.device) self._R = self._R.to(self.device)
self._L = self._L.to(self.device) self._L = self._L.to(self.device)
vel_scale = cfg["velocity"] / 1.0 vel_scale = cfg["vel_max_train"] / 1.0
self.smoothness_weight = cfg["ws"] self.smoothness_weight = cfg["ws"]
self.safety_weight = cfg["wc"] self.safety_weight = cfg["wc"]
self.goal_weight = cfg["wg"] self.goal_weight = cfg["wg"]
self.denormalize_weight(vel_scale) self.denormalize_weight(vel_scale)
self.smoothness_loss = SmoothnessLoss(self._R) self.smoothness_loss = SmoothnessLoss(self._R)
self.safety_loss = SafetyLoss(self._L, self.sgm_time) self.safety_loss = SafetyLoss(self._L)
self.goal_loss = GuidanceLoss() self.goal_loss = GuidanceLoss()
print("------ Actual Loss ------") print("------ Actual Loss ------")
print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |") print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |")

View File

@ -5,23 +5,21 @@ import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import open3d as o3d import open3d as o3d
from ruamel.yaml import YAML
from scipy.ndimage import distance_transform_edt from scipy.ndimage import distance_transform_edt
from config.config import cfg
class SafetyLoss(nn.Module): class SafetyLoss(nn.Module):
def __init__(self, L, sgm_time): def __init__(self, L):
super(SafetyLoss, self).__init__() super(SafetyLoss, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__)) self.traj_num = cfg['traj_num']
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.traj_num = cfg['horizon_num'] * cfg['vertical_num']
self.map_expand_min = np.array(cfg['map_expand_min']) self.map_expand_min = np.array(cfg['map_expand_min'])
self.map_expand_max = np.array(cfg['map_expand_max']) self.map_expand_max = np.array(cfg['map_expand_max'])
self.d0 = cfg["d0"] self.d0 = cfg["d0"]
self.r = cfg["r"] self.r = cfg["r"]
self._L = L self._L = L
self.sgm_time = sgm_time self.sgm_time = cfg["sgm_time"]
self.eval_points = 30 self.eval_points = 30
self.device = self._L.device self.device = self._L.device
@ -31,6 +29,7 @@ class SafetyLoss(nn.Module):
self.max_bounds = None # shape: (N, 3) self.max_bounds = None # shape: (N, 3)
self.sdf_shapes = None # shape: (1, 3) self.sdf_shapes = None # shape: (1, 3)
print("Building ESDF map...") print("Building ESDF map...")
base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"]) data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
self.sdf_maps = self.get_sdf_from_ply(data_dir) self.sdf_maps = self.get_sdf_from_ply(data_dir)
print("Map built!") print("Map built!")

View File

@ -1,16 +1,18 @@
import torch import torch
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
from config.config import cfg
class LatticeParam: class LatticeParam:
def __init__(self, cfg): def __init__(self):
ratio = cfg["velocity"] / cfg["vel_align"] ratio = 1.0 if cfg["train"] else cfg["velocity"] / cfg["vel_max_train"]
self.vel_max = ratio * cfg["vel_align"] self.vel_max = ratio * cfg["vel_max_train"]
self.acc_max = ratio * ratio * cfg["acc_align"] self.acc_max = ratio * ratio * cfg["acc_max_train"]
self.segment_time = 2 * cfg["radio_range"] / self.vel_max self.segment_time = cfg["sgm_time"] / ratio
self.horizon_num = cfg["horizon_num"] self.horizon_num = cfg["horizon_num"]
self.vertical_num = cfg["vertical_num"] self.vertical_num = cfg["vertical_num"]
self.radio_num = cfg["radio_num"] self.radio_num = cfg["radio_num"]
self.traj_num = cfg["traj_num"]
self.horizon_fov = cfg["horizon_camera_fov"] self.horizon_fov = cfg["horizon_camera_fov"]
self.vertical_fov = cfg["vertical_camera_fov"] self.vertical_fov = cfg["vertical_camera_fov"]
self.horizon_anchor_fov = cfg["horizon_anchor_fov"] self.horizon_anchor_fov = cfg["horizon_anchor_fov"]
@ -38,12 +40,10 @@ class LatticePrimitive(LatticeParam):
""" """
_instance = None _instance = None
def __init__(self, cfg): def __init__(self):
super().__init__(cfg) super().__init__()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.traj_num = self.vertical_num * self.horizon_num * self.radio_num
if self.horizon_num == 1: if self.horizon_num == 1:
direction_diff = 0 direction_diff = 0
else: else:
@ -105,16 +105,6 @@ class LatticePrimitive(LatticeParam):
return self.traj_num - id - 1 return self.traj_num - id - 1
@classmethod @classmethod
def get_instance(self, cfg): def get_instance(self):
if self._instance is None: self._instance = self(cfg) if self._instance is None: self._instance = self()
return self._instance return self._instance
if __name__ == '__main__':
import os
from ruamel.yaml import YAML
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
lattice_primitive = LatticePrimitive.get_instance(cfg)
print(lattice_primitive.getStateLattice(list(range(lattice_primitive.traj_num))))

View File

@ -1,16 +1,13 @@
import os
import torch import torch
import numpy as np import numpy as np
from ruamel.yaml import YAML from config.config import cfg
from policy.primitive import LatticePrimitive from policy.primitive import LatticePrimitive
class StateTransform: class StateTransform:
def __init__(self): def __init__(self):
base_dir = os.path.dirname(os.path.abspath(__file__)) self.lattice_primitive = LatticePrimitive.get_instance()
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) self.goal_length = cfg['goal_length']
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
self.goal_length = 2.0 * cfg['radio_range']
def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor: def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor:
""" """

View File

@ -1,38 +1,37 @@
import os import os, sys
import cv2 import cv2
import time
import numpy as np import numpy as np
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from ruamel.yaml import YAML
import time
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from config.config import cfg
class YOPODataset(Dataset): class YOPODataset(Dataset):
def __init__(self, mode='train', val_ratio=0.1): def __init__(self, mode='train', val_ratio=0.1):
super(YOPODataset, self).__init__() super(YOPODataset, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
# image params # image params
self.height = int(cfg["image_height"]) self.height = int(cfg["image_height"])
self.width = int(cfg["image_width"]) self.width = int(cfg["image_width"])
# ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution # ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution
scale = cfg["velocity"] / cfg["vel_align"] self.vel_max = cfg["vel_max_train"]
self.vel_max = scale * cfg["vel_align"] self.acc_max = cfg["acc_max_train"]
self.acc_max = scale * scale * cfg["acc_align"]
self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"]) self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"])
self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"]) self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"])
self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]]) self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]]) self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]])
self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]]) self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]]) self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]])
self.goal_length = 2.0 * cfg['radio_range'] self.goal_length = cfg['goal_length']
self.goal_pitch_std = cfg["goal_pitch_std"] self.goal_pitch_std = cfg["goal_pitch_std"]
self.goal_yaw_std = cfg["goal_yaw_std"] self.goal_yaw_std = cfg["goal_yaw_std"]
if mode == 'train': self.print_data() if mode == 'train': self.print_data()
# dataset # dataset
print("Loading", mode, "dataset, it may take a while...") print("Loading", mode, "dataset, it may take a while...")
base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"]) data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32) self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)

View File

@ -2,6 +2,7 @@
Training Strategy Training Strategy
supervised learning, imitation learning, testing, rollout supervised learning, imitation learning, testing, rollout
""" """
import os
import time import time
import atexit import atexit
from torch.nn import functional as F from torch.nn import functional as F
@ -9,6 +10,7 @@ from rich.progress import Progress
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter from torch.utils.tensorboard.writer import SummaryWriter
from config.config import cfg
from loss.loss_function import YOPOLoss from loss.loss_function import YOPOLoss
from policy.yopo_network import YopoNetwork from policy.yopo_network import YopoNetwork
from policy.yopo_dataset import YOPODataset from policy.yopo_dataset import YOPODataset
@ -35,10 +37,7 @@ class YopoTrainer:
self.tensorboard_path = self.get_next_log_path(tensorboard_path) self.tensorboard_path = self.get_next_log_path(tensorboard_path)
self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path) self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path)
# params # params
base_dir = os.path.dirname(os.path.abspath(__file__)) self.traj_num = cfg['traj_num']
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
self.traj_num = self.lattice_primitive.traj_num
# loss # loss
self.yopo_loss = YOPOLoss() self.yopo_loss = YOPOLoss()

View File

@ -13,9 +13,9 @@ import time
import torch import torch
import numpy as np import numpy as np
import argparse import argparse
from ruamel.yaml import YAML
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
from config.config import cfg
from control_msg import PositionCommand from control_msg import PositionCommand
from policy.yopo_network import YopoNetwork from policy.yopo_network import YopoNetwork
from policy.poly_solver import * from policy.poly_solver import *
@ -32,8 +32,7 @@ class YopoNet:
self.config = config self.config = config
rospy.init_node('yopo_net', anonymous=False) rospy.init_node('yopo_net', anonymous=False)
# load params # load params
base_dir = os.path.dirname(os.path.abspath(__file__)) cfg["train"] = False
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
self.height = cfg['image_height'] self.height = cfg['image_height']
self.width = cfg['image_width'] self.width = cfg['image_width']
self.min_dis, self.max_dis = 0.04, 20.0 self.min_dis, self.max_dis = 0.04, 20.0
@ -64,7 +63,7 @@ class YopoNet:
self.lock = Lock() self.lock = Lock()
self.last_control_msg = None self.last_control_msg = None
self.state_transform = StateTransform() self.state_transform = StateTransform()
self.lattice_primitive = LatticePrimitive.get_instance(cfg) self.lattice_primitive = LatticePrimitive.get_instance()
self.traj_time = self.lattice_primitive.segment_time self.traj_time = self.lattice_primitive.segment_time
# eval # eval
@ -354,7 +353,7 @@ def parser():
return parser return parser
def main(): if __name__ == "__main__":
args = parser().parse_args() args = parser().parse_args()
base_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = os.path.dirname(os.path.abspath(__file__))
weight = "yopo_trt.pth" if args.use_tensorrt else base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch) weight = "yopo_trt.pth" if args.use_tensorrt else base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
@ -372,7 +371,3 @@ def main():
'visualize': True # 可视化所有轨迹?(实飞改为False节省计算) 'visualize': True # 可视化所有轨迹?(实飞改为False节省计算)
} }
YopoNet(settings, weight) YopoNet(settings, weight)
if __name__ == "__main__":
main()

View File

@ -24,10 +24,9 @@ def parser():
return parser return parser
def main(): if __name__ == "__main__":
args = parser().parse_args() args = parser().parse_args()
# set random seed configure_random_seed(0) # set random seed
configure_random_seed(0)
# save the configuration and other files # save the configuration and other files
log_dir = os.path.dirname(os.path.abspath(__file__)) + "/saved" log_dir = os.path.dirname(os.path.abspath(__file__)) + "/saved"
@ -46,7 +45,3 @@ def main():
trainer.train(epoch=50) trainer.train(epoch=50)
print("Run YOPO Finish!") print("Run YOPO Finish!")
if __name__ == "__main__":
main()

View File

@ -1,12 +1,13 @@
Panels: Panels:
- Class: rviz/Displays - Class: rviz/Displays
Help Height: 78 Help Height: 138
Name: Displays Name: Displays
Property Tree Widget: Property Tree Widget:
Expanded: Expanded:
- /PointCloud21/Autocompute Value Bounds1 - /Map1/Autocompute Value Bounds1
Splitter Ratio: 0.5 - /Trajectory1
Tree Height: 326 Splitter Ratio: 0.6625221967697144
Tree Height: 476
- Class: rviz/Selection - Class: rviz/Selection
Name: Selection Name: Selection
- Class: rviz/Tool Properties - Class: rviz/Tool Properties
@ -24,7 +25,7 @@ Panels:
- Class: rviz/Time - Class: rviz/Time
Name: Time Name: Time
SyncMode: 0 SyncMode: 0
SyncSource: Image SyncSource: Depth
Preferences: Preferences:
PromptSaveOnExit: true PromptSaveOnExit: true
Toolbars: Toolbars:
@ -56,7 +57,7 @@ Visualization Manager:
Max Value: 1 Max Value: 1
Median window: 5 Median window: 5
Min Value: 0 Min Value: 0
Name: Image Name: Depth
Normalize Range: true Normalize Range: true
Queue Size: 2 Queue Size: 2
Transport Hint: raw Transport Hint: raw
@ -78,7 +79,7 @@ Visualization Manager:
Invert Rainbow: false Invert Rainbow: false
Max Color: 255; 255; 255 Max Color: 255; 255; 255
Min Color: 0; 0; 0 Min Color: 0; 0; 0
Name: PointCloud2 Name: Map
Position Transformer: XYZ Position Transformer: XYZ
Queue Size: 10 Queue Size: 10
Selectable: true Selectable: true
@ -108,7 +109,7 @@ Visualization Manager:
Invert Rainbow: false Invert Rainbow: false
Max Color: 255; 255; 255 Max Color: 255; 255; 255
Min Color: 0; 0; 0 Min Color: 0; 0; 0
Name: PointCloud2 Name: Best_Traj
Position Transformer: XYZ Position Transformer: XYZ
Queue Size: 10 Queue Size: 10
Selectable: true Selectable: true
@ -136,7 +137,7 @@ Visualization Manager:
Invert Rainbow: true Invert Rainbow: true
Max Color: 255; 255; 255 Max Color: 255; 255; 255
Min Color: 0; 0; 0 Min Color: 0; 0; 0
Name: PointCloud2 Name: All_traj
Position Transformer: XYZ Position Transformer: XYZ
Queue Size: 10 Queue Size: 10
Selectable: true Selectable: true
@ -164,7 +165,7 @@ Visualization Manager:
Invert Rainbow: false Invert Rainbow: false
Max Color: 255; 255; 255 Max Color: 255; 255; 255
Min Color: 0; 0; 0 Min Color: 0; 0; 0
Name: PointCloud2 Name: Primitive_Traj
Position Transformer: XYZ Position Transformer: XYZ
Queue Size: 10 Queue Size: 10
Selectable: true Selectable: true
@ -177,11 +178,11 @@ Visualization Manager:
Use rainbow: true Use rainbow: true
Value: false Value: false
Enabled: true Enabled: true
Name: Traj Name: Trajectory
- Class: rviz/Marker - Class: rviz/Marker
Enabled: true Enabled: true
Marker Topic: /quadrotor_simulator_so3/uav Marker Topic: /quadrotor_simulator_so3/uav
Name: Marker Name: Drone
Namespaces: Namespaces:
mesh: true mesh: true
Queue Size: 100 Queue Size: 100
@ -235,14 +236,14 @@ Visualization Manager:
Yaw: 3.140000104904175 Yaw: 3.140000104904175
Saved: ~ Saved: ~
Window Geometry: Window Geometry:
Depth:
collapsed: false
Displays: Displays:
collapsed: false collapsed: false
Height: 1016 Height: 1600
Hide Left Dock: false Hide Left Dock: false
Hide Right Dock: true Hide Right Dock: true
Image: QMainWindow State: 000000ff00000000fd0000000400000000000003310000053afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b000000b000fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000b0fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000006e000002d40000018200fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a00440065007000740068010000034e0000025a0000002600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a0000013200fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e1000001970000000300000b700000005efc0100000002fb0000000800540069006d0065010000000000000b70000006dc00fffffffb0000000800540069006d00650100000000000004500000000000000000000008330000053a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
collapsed: false
QMainWindow State: 000000ff00000000fd0000000400000000000002350000035afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000001d1000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a0049006d0061006700650100000214000001830000001600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000007380000003efc0100000002fb0000000800540069006d0065010000000000000738000003bc00fffffffb0000000800540069006d00650100000000000004500000000000000000000004fd0000035a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
Selection: Selection:
collapsed: false collapsed: false
Time: Time:
@ -251,6 +252,6 @@ Window Geometry:
collapsed: false collapsed: false
Views: Views:
collapsed: true collapsed: true
Width: 1848 Width: 2928
X: 72 X: 144
Y: 387 Y: 54

View File

@ -7,13 +7,13 @@
python setup.py install python setup.py install
""" """
import argparse
import os import os
import argparse
import time
import numpy as np import numpy as np
import torch import torch
from torch2trt import torch2trt from torch2trt import torch2trt
from ruamel.yaml import YAML from config.config import cfg
import time
from policy.yopo_network import YopoNetwork from policy.yopo_network import YopoNetwork
@ -25,16 +25,15 @@ def parser():
return parser return parser
def main(): if __name__ == "__main__":
args = parser().parse_args() args = parser().parse_args()
base_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
weight = base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch) weight = base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
print("Loading Network...") print("Loading Network...")
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
state_dict = torch.load(weight, weights_only=True) state_dict = torch.load(weight, weights_only=True)
policy = YopoNetwork(horizon_num=cfg["horizon_num"], vertical_num=cfg["vertical_num"]) policy = YopoNetwork()
policy.load_state_dict(state_dict) policy.load_state_dict(state_dict)
policy = policy.to(device) policy = policy.to(device)
policy.eval() policy.eval()
@ -77,6 +76,3 @@ def main():
f"Transfer Trajectory Error: {traj_error.item():.6f}," f"Transfer Trajectory Error: {traj_error.item():.6f},"
f"Transfer Score Error: {score_error.item():.6f}") f"Transfer Score Error: {score_error.item():.6f}")
if __name__ == "__main__":
main()