Modify to global config and simplified speed adjustment during training and testing

This commit is contained in:
TJU_Lu 2025-07-06 21:31:41 +08:00
parent 788e9bc979
commit bd94cfbd51
14 changed files with 98 additions and 110 deletions

0
YOPO/config/__init__.py Normal file
View File

22
YOPO/config/config.py Normal file
View File

@ -0,0 +1,22 @@
import os
from ruamel.yaml import YAML
# Global Configuration Management
class Config:
def __init__(self):
base_dir = os.path.dirname(os.path.abspath(__file__))
self._data = YAML().load(open(os.path.join(base_dir, "traj_opt.yaml"), 'r'))
self._data["train"] = True
self._data["goal_length"] = 2.0 * self._data['radio_range']
self._data["sgm_time"] = 2 * self._data["radio_range"] / self._data["vel_max_train"]
self._data["traj_num"] = self._data['horizon_num'] * self._data['vertical_num'] * self._data["radio_num"]
def __getitem__(self, key):
return self._data[key]
def __setitem__(self, key, value):
self._data[key] = value
cfg = Config()

View File

@ -1,11 +1,12 @@
# IMPORTANT PARAM: actual velocity in training / testing
# IMPORTANT: velocity in testing (modifiable)
velocity: 6.0
# used to align the vel/acc, ensure consistency between testing and training.
# during testing, if vel*n then acc*n*n, please refer to ../policy/primitive.py
vel_align: 6.0
acc_align: 6.0
# IMPORTANT PARAM: weight of penalties (for unit speed)
# IMPORTANT: vel_max and acc_max in training
# not the actual values in testing, ensure consistency between testing and training, please refer to ../policy/primitive.py
vel_max_train: 6.0
acc_max_train: 6.0
# IMPORTANT: weight of costs for unit speed (can be visualized in tensorboard)
wg: 0.1 # guidance
ws: 10.0 # smoothness
wc: 0.1 # collision
@ -15,7 +16,7 @@ dataset_path: "../dataset"
image_height: 96
image_width: 160
# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max)
# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max) (can be visualized in rviz)
horizon_num: 5
vertical_num: 3
horizon_camera_fov: 90.0
@ -29,7 +30,7 @@ radio_num: 1 # only support 1 currently
d0: 1.2
r: 0.6
# distribution parameters for unit state sampling
# distribution for unit state sampling (can be visualized by 'python ./policy/yopo_dataset.py')
vx_mean_unit: 0.4
vy_mean_unit: 0.0
vz_mean_unit: 0.0

View File

@ -1,16 +1,13 @@
import os
import torch.nn as nn
import torch as th
import torch.nn.functional as F
from ruamel.yaml import YAML
from config.config import cfg
class GuidanceLoss(nn.Module):
def __init__(self):
super(GuidanceLoss, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.goal_length = 2.0 * cfg['radio_range']
self.goal_length = cfg['goal_length']
def forward(self, Df, Dp, goal):
"""

View File

@ -1,8 +1,7 @@
import os
import math
import torch as th
import torch.nn as nn
from ruamel.yaml import YAML
from config.config import cfg
from loss.safety_loss import SafetyLoss
from loss.smoothness_loss import SmoothnessLoss
from loss.guidance_loss import GuidanceLoss
@ -17,20 +16,18 @@ class YOPOLoss(nn.Module):
df: fixed parameters
"""
super(YOPOLoss, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.sgm_time = 2 * cfg["radio_range"] / cfg["velocity"]
self.sgm_time = cfg["sgm_time"]
self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
self._C, self._B, self._L, self._R = self.qp_generation()
self._R = self._R.to(self.device)
self._L = self._L.to(self.device)
vel_scale = cfg["velocity"] / 1.0
vel_scale = cfg["vel_max_train"] / 1.0
self.smoothness_weight = cfg["ws"]
self.safety_weight = cfg["wc"]
self.goal_weight = cfg["wg"]
self.denormalize_weight(vel_scale)
self.smoothness_loss = SmoothnessLoss(self._R)
self.safety_loss = SafetyLoss(self._L, self.sgm_time)
self.safety_loss = SafetyLoss(self._L)
self.goal_loss = GuidanceLoss()
print("------ Actual Loss ------")
print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |")

View File

@ -5,23 +5,21 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import open3d as o3d
from ruamel.yaml import YAML
from scipy.ndimage import distance_transform_edt
from config.config import cfg
class SafetyLoss(nn.Module):
def __init__(self, L, sgm_time):
def __init__(self, L):
super(SafetyLoss, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.traj_num = cfg['horizon_num'] * cfg['vertical_num']
self.traj_num = cfg['traj_num']
self.map_expand_min = np.array(cfg['map_expand_min'])
self.map_expand_max = np.array(cfg['map_expand_max'])
self.d0 = cfg["d0"]
self.r = cfg["r"]
self._L = L
self.sgm_time = sgm_time
self.sgm_time = cfg["sgm_time"]
self.eval_points = 30
self.device = self._L.device
@ -31,6 +29,7 @@ class SafetyLoss(nn.Module):
self.max_bounds = None # shape: (N, 3)
self.sdf_shapes = None # shape: (1, 3)
print("Building ESDF map...")
base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
self.sdf_maps = self.get_sdf_from_ply(data_dir)
print("Map built!")

View File

@ -1,16 +1,18 @@
import torch
from scipy.spatial.transform import Rotation as R
from config.config import cfg
class LatticeParam:
def __init__(self, cfg):
ratio = cfg["velocity"] / cfg["vel_align"]
self.vel_max = ratio * cfg["vel_align"]
self.acc_max = ratio * ratio * cfg["acc_align"]
self.segment_time = 2 * cfg["radio_range"] / self.vel_max
def __init__(self):
ratio = 1.0 if cfg["train"] else cfg["velocity"] / cfg["vel_max_train"]
self.vel_max = ratio * cfg["vel_max_train"]
self.acc_max = ratio * ratio * cfg["acc_max_train"]
self.segment_time = cfg["sgm_time"] / ratio
self.horizon_num = cfg["horizon_num"]
self.vertical_num = cfg["vertical_num"]
self.radio_num = cfg["radio_num"]
self.traj_num = cfg["traj_num"]
self.horizon_fov = cfg["horizon_camera_fov"]
self.vertical_fov = cfg["vertical_camera_fov"]
self.horizon_anchor_fov = cfg["horizon_anchor_fov"]
@ -38,12 +40,10 @@ class LatticePrimitive(LatticeParam):
"""
_instance = None
def __init__(self, cfg):
super().__init__(cfg)
def __init__(self):
super().__init__()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.traj_num = self.vertical_num * self.horizon_num * self.radio_num
if self.horizon_num == 1:
direction_diff = 0
else:
@ -105,16 +105,6 @@ class LatticePrimitive(LatticeParam):
return self.traj_num - id - 1
@classmethod
def get_instance(self, cfg):
if self._instance is None: self._instance = self(cfg)
def get_instance(self):
if self._instance is None: self._instance = self()
return self._instance
if __name__ == '__main__':
import os
from ruamel.yaml import YAML
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
lattice_primitive = LatticePrimitive.get_instance(cfg)
print(lattice_primitive.getStateLattice(list(range(lattice_primitive.traj_num))))

View File

@ -1,16 +1,13 @@
import os
import torch
import numpy as np
from ruamel.yaml import YAML
from config.config import cfg
from policy.primitive import LatticePrimitive
class StateTransform:
def __init__(self):
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
self.goal_length = 2.0 * cfg['radio_range']
self.lattice_primitive = LatticePrimitive.get_instance()
self.goal_length = cfg['goal_length']
def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor:
"""

View File

@ -1,38 +1,37 @@
import os
import os, sys
import cv2
import time
import numpy as np
from torch.utils.data import Dataset, DataLoader
from ruamel.yaml import YAML
import time
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from config.config import cfg
class YOPODataset(Dataset):
def __init__(self, mode='train', val_ratio=0.1):
super(YOPODataset, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
# image params
self.height = int(cfg["image_height"])
self.width = int(cfg["image_width"])
# ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution
scale = cfg["velocity"] / cfg["vel_align"]
self.vel_max = scale * cfg["vel_align"]
self.acc_max = scale * scale * cfg["acc_align"]
self.vel_max = cfg["vel_max_train"]
self.acc_max = cfg["acc_max_train"]
self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"])
self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"])
self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]])
self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]])
self.goal_length = 2.0 * cfg['radio_range']
self.goal_length = cfg['goal_length']
self.goal_pitch_std = cfg["goal_pitch_std"]
self.goal_yaw_std = cfg["goal_yaw_std"]
if mode == 'train': self.print_data()
# dataset
print("Loading", mode, "dataset, it may take a while...")
base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)

View File

@ -2,6 +2,7 @@
Training Strategy
supervised learning, imitation learning, testing, rollout
"""
import os
import time
import atexit
from torch.nn import functional as F
@ -9,6 +10,7 @@ from rich.progress import Progress
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from config.config import cfg
from loss.loss_function import YOPOLoss
from policy.yopo_network import YopoNetwork
from policy.yopo_dataset import YOPODataset
@ -35,10 +37,7 @@ class YopoTrainer:
self.tensorboard_path = self.get_next_log_path(tensorboard_path)
self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path)
# params
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
self.traj_num = self.lattice_primitive.traj_num
self.traj_num = cfg['traj_num']
# loss
self.yopo_loss = YOPOLoss()

View File

@ -13,9 +13,9 @@ import time
import torch
import numpy as np
import argparse
from ruamel.yaml import YAML
from scipy.spatial.transform import Rotation as R
from config.config import cfg
from control_msg import PositionCommand
from policy.yopo_network import YopoNetwork
from policy.poly_solver import *
@ -32,8 +32,7 @@ class YopoNet:
self.config = config
rospy.init_node('yopo_net', anonymous=False)
# load params
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
cfg["train"] = False
self.height = cfg['image_height']
self.width = cfg['image_width']
self.min_dis, self.max_dis = 0.04, 20.0
@ -64,7 +63,7 @@ class YopoNet:
self.lock = Lock()
self.last_control_msg = None
self.state_transform = StateTransform()
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
self.lattice_primitive = LatticePrimitive.get_instance()
self.traj_time = self.lattice_primitive.segment_time
# eval
@ -354,7 +353,7 @@ def parser():
return parser
def main():
if __name__ == "__main__":
args = parser().parse_args()
base_dir = os.path.dirname(os.path.abspath(__file__))
weight = "yopo_trt.pth" if args.use_tensorrt else base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
@ -372,7 +371,3 @@ def main():
'visualize': True # 可视化所有轨迹?(实飞改为False节省计算)
}
YopoNet(settings, weight)
if __name__ == "__main__":
main()

View File

@ -24,10 +24,9 @@ def parser():
return parser
def main():
if __name__ == "__main__":
args = parser().parse_args()
# set random seed
configure_random_seed(0)
configure_random_seed(0) # set random seed
# save the configuration and other files
log_dir = os.path.dirname(os.path.abspath(__file__)) + "/saved"
@ -46,7 +45,3 @@ def main():
trainer.train(epoch=50)
print("Run YOPO Finish!")
if __name__ == "__main__":
main()

View File

@ -1,12 +1,13 @@
Panels:
- Class: rviz/Displays
Help Height: 78
Help Height: 138
Name: Displays
Property Tree Widget:
Expanded:
- /PointCloud21/Autocompute Value Bounds1
Splitter Ratio: 0.5
Tree Height: 326
- /Map1/Autocompute Value Bounds1
- /Trajectory1
Splitter Ratio: 0.6625221967697144
Tree Height: 476
- Class: rviz/Selection
Name: Selection
- Class: rviz/Tool Properties
@ -24,7 +25,7 @@ Panels:
- Class: rviz/Time
Name: Time
SyncMode: 0
SyncSource: Image
SyncSource: Depth
Preferences:
PromptSaveOnExit: true
Toolbars:
@ -56,7 +57,7 @@ Visualization Manager:
Max Value: 1
Median window: 5
Min Value: 0
Name: Image
Name: Depth
Normalize Range: true
Queue Size: 2
Transport Hint: raw
@ -78,7 +79,7 @@ Visualization Manager:
Invert Rainbow: false
Max Color: 255; 255; 255
Min Color: 0; 0; 0
Name: PointCloud2
Name: Map
Position Transformer: XYZ
Queue Size: 10
Selectable: true
@ -108,7 +109,7 @@ Visualization Manager:
Invert Rainbow: false
Max Color: 255; 255; 255
Min Color: 0; 0; 0
Name: PointCloud2
Name: Best_Traj
Position Transformer: XYZ
Queue Size: 10
Selectable: true
@ -136,7 +137,7 @@ Visualization Manager:
Invert Rainbow: true
Max Color: 255; 255; 255
Min Color: 0; 0; 0
Name: PointCloud2
Name: All_traj
Position Transformer: XYZ
Queue Size: 10
Selectable: true
@ -164,7 +165,7 @@ Visualization Manager:
Invert Rainbow: false
Max Color: 255; 255; 255
Min Color: 0; 0; 0
Name: PointCloud2
Name: Primitive_Traj
Position Transformer: XYZ
Queue Size: 10
Selectable: true
@ -177,11 +178,11 @@ Visualization Manager:
Use rainbow: true
Value: false
Enabled: true
Name: Traj
Name: Trajectory
- Class: rviz/Marker
Enabled: true
Marker Topic: /quadrotor_simulator_so3/uav
Name: Marker
Name: Drone
Namespaces:
mesh: true
Queue Size: 100
@ -235,14 +236,14 @@ Visualization Manager:
Yaw: 3.140000104904175
Saved: ~
Window Geometry:
Depth:
collapsed: false
Displays:
collapsed: false
Height: 1016
Height: 1600
Hide Left Dock: false
Hide Right Dock: true
Image:
collapsed: false
QMainWindow State: 000000ff00000000fd0000000400000000000002350000035afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000001d1000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a0049006d0061006700650100000214000001830000001600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000007380000003efc0100000002fb0000000800540069006d0065010000000000000738000003bc00fffffffb0000000800540069006d00650100000000000004500000000000000000000004fd0000035a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
QMainWindow State: 000000ff00000000fd0000000400000000000003310000053afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b000000b000fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000b0fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000006e000002d40000018200fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a00440065007000740068010000034e0000025a0000002600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a0000013200fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e1000001970000000300000b700000005efc0100000002fb0000000800540069006d0065010000000000000b70000006dc00fffffffb0000000800540069006d00650100000000000004500000000000000000000008330000053a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
Selection:
collapsed: false
Time:
@ -251,6 +252,6 @@ Window Geometry:
collapsed: false
Views:
collapsed: true
Width: 1848
X: 72
Y: 387
Width: 2928
X: 144
Y: 54

View File

@ -7,13 +7,13 @@
python setup.py install
"""
import argparse
import os
import argparse
import time
import numpy as np
import torch
from torch2trt import torch2trt
from ruamel.yaml import YAML
import time
from config.config import cfg
from policy.yopo_network import YopoNetwork
@ -25,16 +25,15 @@ def parser():
return parser
def main():
if __name__ == "__main__":
args = parser().parse_args()
base_dir = os.path.dirname(os.path.abspath(__file__))
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
weight = base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
print("Loading Network...")
device = "cuda" if torch.cuda.is_available() else "cpu"
state_dict = torch.load(weight, weights_only=True)
policy = YopoNetwork(horizon_num=cfg["horizon_num"], vertical_num=cfg["vertical_num"])
policy = YopoNetwork()
policy.load_state_dict(state_dict)
policy = policy.to(device)
policy.eval()
@ -77,6 +76,3 @@ def main():
f"Transfer Trajectory Error: {traj_error.item():.6f},"
f"Transfer Score Error: {score_error.item():.6f}")
if __name__ == "__main__":
main()