diff --git a/YOPO/config/__init__.py b/YOPO/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/YOPO/config/config.py b/YOPO/config/config.py new file mode 100644 index 0000000..1bb8ddc --- /dev/null +++ b/YOPO/config/config.py @@ -0,0 +1,22 @@ +import os +from ruamel.yaml import YAML + + +# Global Configuration Management +class Config: + def __init__(self): + base_dir = os.path.dirname(os.path.abspath(__file__)) + self._data = YAML().load(open(os.path.join(base_dir, "traj_opt.yaml"), 'r')) + self._data["train"] = True + self._data["goal_length"] = 2.0 * self._data['radio_range'] + self._data["sgm_time"] = 2 * self._data["radio_range"] / self._data["vel_max_train"] + self._data["traj_num"] = self._data['horizon_num'] * self._data['vertical_num'] * self._data["radio_num"] + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + self._data[key] = value + + +cfg = Config() diff --git a/YOPO/config/traj_opt.yaml b/YOPO/config/traj_opt.yaml index 78eacc0..4994aae 100644 --- a/YOPO/config/traj_opt.yaml +++ b/YOPO/config/traj_opt.yaml @@ -1,11 +1,12 @@ -# IMPORTANT PARAM: actual velocity in training / testing +# IMPORTANT: velocity in testing (modifiable) velocity: 6.0 -# used to align the vel/acc, ensure consistency between testing and training. -# during testing, if vel*n then acc*n*n, please refer to ../policy/primitive.py -vel_align: 6.0 -acc_align: 6.0 -# IMPORTANT PARAM: weight of penalties (for unit speed) +# IMPORTANT: vel_max and acc_max in training +# not the actual values in testing, ensure consistency between testing and training, please refer to ../policy/primitive.py +vel_max_train: 6.0 +acc_max_train: 6.0 + +# IMPORTANT: weight of costs for unit speed (can be visualized in tensorboard) wg: 0.1 # guidance ws: 10.0 # smoothness wc: 0.1 # collision @@ -15,7 +16,7 @@ dataset_path: "../dataset" image_height: 96 image_width: 160 -# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max) +# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max) (can be visualized in rviz) horizon_num: 5 vertical_num: 3 horizon_camera_fov: 90.0 @@ -29,7 +30,7 @@ radio_num: 1 # only support 1 currently d0: 1.2 r: 0.6 -# distribution parameters for unit state sampling +# distribution for unit state sampling (can be visualized by 'python ./policy/yopo_dataset.py') vx_mean_unit: 0.4 vy_mean_unit: 0.0 vz_mean_unit: 0.0 diff --git a/YOPO/loss/guidance_loss.py b/YOPO/loss/guidance_loss.py index 5b8ab04..a4646a3 100644 --- a/YOPO/loss/guidance_loss.py +++ b/YOPO/loss/guidance_loss.py @@ -1,16 +1,13 @@ -import os import torch.nn as nn import torch as th import torch.nn.functional as F -from ruamel.yaml import YAML +from config.config import cfg class GuidanceLoss(nn.Module): def __init__(self): super(GuidanceLoss, self).__init__() - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) - self.goal_length = 2.0 * cfg['radio_range'] + self.goal_length = cfg['goal_length'] def forward(self, Df, Dp, goal): """ diff --git a/YOPO/loss/loss_function.py b/YOPO/loss/loss_function.py index 6bd0d34..eb23c36 100644 --- a/YOPO/loss/loss_function.py +++ b/YOPO/loss/loss_function.py @@ -1,8 +1,7 @@ -import os import math import torch as th import torch.nn as nn -from ruamel.yaml import YAML +from config.config import cfg from loss.safety_loss import SafetyLoss from loss.smoothness_loss import SmoothnessLoss from loss.guidance_loss import GuidanceLoss @@ -17,20 +16,18 @@ class YOPOLoss(nn.Module): df: fixed parameters """ super(YOPOLoss, self).__init__() - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) - self.sgm_time = 2 * cfg["radio_range"] / cfg["velocity"] + self.sgm_time = cfg["sgm_time"] self.device = th.device("cuda" if th.cuda.is_available() else "cpu") self._C, self._B, self._L, self._R = self.qp_generation() self._R = self._R.to(self.device) self._L = self._L.to(self.device) - vel_scale = cfg["velocity"] / 1.0 + vel_scale = cfg["vel_max_train"] / 1.0 self.smoothness_weight = cfg["ws"] self.safety_weight = cfg["wc"] self.goal_weight = cfg["wg"] self.denormalize_weight(vel_scale) self.smoothness_loss = SmoothnessLoss(self._R) - self.safety_loss = SafetyLoss(self._L, self.sgm_time) + self.safety_loss = SafetyLoss(self._L) self.goal_loss = GuidanceLoss() print("------ Actual Loss ------") print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |") diff --git a/YOPO/loss/safety_loss.py b/YOPO/loss/safety_loss.py index 44cd50f..e1310bc 100644 --- a/YOPO/loss/safety_loss.py +++ b/YOPO/loss/safety_loss.py @@ -5,23 +5,21 @@ import torch as th import torch.nn as nn import torch.nn.functional as F import open3d as o3d -from ruamel.yaml import YAML from scipy.ndimage import distance_transform_edt +from config.config import cfg class SafetyLoss(nn.Module): - def __init__(self, L, sgm_time): + def __init__(self, L): super(SafetyLoss, self).__init__() - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) - self.traj_num = cfg['horizon_num'] * cfg['vertical_num'] + self.traj_num = cfg['traj_num'] self.map_expand_min = np.array(cfg['map_expand_min']) self.map_expand_max = np.array(cfg['map_expand_max']) self.d0 = cfg["d0"] self.r = cfg["r"] self._L = L - self.sgm_time = sgm_time + self.sgm_time = cfg["sgm_time"] self.eval_points = 30 self.device = self._L.device @@ -31,6 +29,7 @@ class SafetyLoss(nn.Module): self.max_bounds = None # shape: (N, 3) self.sdf_shapes = None # shape: (1, 3) print("Building ESDF map...") + base_dir = os.path.dirname(os.path.abspath(__file__)) data_dir = os.path.join(base_dir, "../", cfg["dataset_path"]) self.sdf_maps = self.get_sdf_from_ply(data_dir) print("Map built!") diff --git a/YOPO/policy/primitive.py b/YOPO/policy/primitive.py index ecd9e1e..f392bd2 100644 --- a/YOPO/policy/primitive.py +++ b/YOPO/policy/primitive.py @@ -1,16 +1,18 @@ import torch from scipy.spatial.transform import Rotation as R +from config.config import cfg class LatticeParam: - def __init__(self, cfg): - ratio = cfg["velocity"] / cfg["vel_align"] - self.vel_max = ratio * cfg["vel_align"] - self.acc_max = ratio * ratio * cfg["acc_align"] - self.segment_time = 2 * cfg["radio_range"] / self.vel_max + def __init__(self): + ratio = 1.0 if cfg["train"] else cfg["velocity"] / cfg["vel_max_train"] + self.vel_max = ratio * cfg["vel_max_train"] + self.acc_max = ratio * ratio * cfg["acc_max_train"] + self.segment_time = cfg["sgm_time"] / ratio self.horizon_num = cfg["horizon_num"] self.vertical_num = cfg["vertical_num"] self.radio_num = cfg["radio_num"] + self.traj_num = cfg["traj_num"] self.horizon_fov = cfg["horizon_camera_fov"] self.vertical_fov = cfg["vertical_camera_fov"] self.horizon_anchor_fov = cfg["horizon_anchor_fov"] @@ -38,12 +40,10 @@ class LatticePrimitive(LatticeParam): """ _instance = None - def __init__(self, cfg): - super().__init__(cfg) + def __init__(self): + super().__init__() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.traj_num = self.vertical_num * self.horizon_num * self.radio_num - if self.horizon_num == 1: direction_diff = 0 else: @@ -105,16 +105,6 @@ class LatticePrimitive(LatticeParam): return self.traj_num - id - 1 @classmethod - def get_instance(self, cfg): - if self._instance is None: self._instance = self(cfg) + def get_instance(self): + if self._instance is None: self._instance = self() return self._instance - - -if __name__ == '__main__': - import os - from ruamel.yaml import YAML - - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) - lattice_primitive = LatticePrimitive.get_instance(cfg) - print(lattice_primitive.getStateLattice(list(range(lattice_primitive.traj_num)))) diff --git a/YOPO/policy/state_transform.py b/YOPO/policy/state_transform.py index b4a0cba..eb0f30c 100644 --- a/YOPO/policy/state_transform.py +++ b/YOPO/policy/state_transform.py @@ -1,16 +1,13 @@ -import os import torch import numpy as np -from ruamel.yaml import YAML +from config.config import cfg from policy.primitive import LatticePrimitive class StateTransform: def __init__(self): - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) - self.lattice_primitive = LatticePrimitive.get_instance(cfg) - self.goal_length = 2.0 * cfg['radio_range'] + self.lattice_primitive = LatticePrimitive.get_instance() + self.goal_length = cfg['goal_length'] def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor: """ diff --git a/YOPO/policy/yopo_dataset.py b/YOPO/policy/yopo_dataset.py index 6ba1ac3..6fa16ff 100644 --- a/YOPO/policy/yopo_dataset.py +++ b/YOPO/policy/yopo_dataset.py @@ -1,38 +1,37 @@ -import os +import os, sys import cv2 +import time import numpy as np from torch.utils.data import Dataset, DataLoader -from ruamel.yaml import YAML -import time from scipy.spatial.transform import Rotation as R from sklearn.model_selection import train_test_split +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from config.config import cfg class YOPODataset(Dataset): def __init__(self, mode='train', val_ratio=0.1): super(YOPODataset, self).__init__() - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) # image params self.height = int(cfg["image_height"]) self.width = int(cfg["image_width"]) # ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution - scale = cfg["velocity"] / cfg["vel_align"] - self.vel_max = scale * cfg["vel_align"] - self.acc_max = scale * scale * cfg["acc_align"] + self.vel_max = cfg["vel_max_train"] + self.acc_max = cfg["acc_max_train"] self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"]) self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"]) self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]]) self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]]) self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]]) self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]]) - self.goal_length = 2.0 * cfg['radio_range'] + self.goal_length = cfg['goal_length'] self.goal_pitch_std = cfg["goal_pitch_std"] self.goal_yaw_std = cfg["goal_yaw_std"] if mode == 'train': self.print_data() # dataset print("Loading", mode, "dataset, it may take a while...") + base_dir = os.path.dirname(os.path.abspath(__file__)) data_dir = os.path.join(base_dir, "../", cfg["dataset_path"]) self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32) diff --git a/YOPO/policy/yopo_trainer.py b/YOPO/policy/yopo_trainer.py index 9e46716..f07352a 100644 --- a/YOPO/policy/yopo_trainer.py +++ b/YOPO/policy/yopo_trainer.py @@ -2,6 +2,7 @@ Training Strategy supervised learning, imitation learning, testing, rollout """ +import os import time import atexit from torch.nn import functional as F @@ -9,6 +10,7 @@ from rich.progress import Progress from torch.utils.data import DataLoader from torch.utils.tensorboard.writer import SummaryWriter +from config.config import cfg from loss.loss_function import YOPOLoss from policy.yopo_network import YopoNetwork from policy.yopo_dataset import YOPODataset @@ -35,10 +37,7 @@ class YopoTrainer: self.tensorboard_path = self.get_next_log_path(tensorboard_path) self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path) # params - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r')) - self.lattice_primitive = LatticePrimitive.get_instance(cfg) - self.traj_num = self.lattice_primitive.traj_num + self.traj_num = cfg['traj_num'] # loss self.yopo_loss = YOPOLoss() diff --git a/YOPO/test_yopo_ros.py b/YOPO/test_yopo_ros.py index 6905506..11694fe 100644 --- a/YOPO/test_yopo_ros.py +++ b/YOPO/test_yopo_ros.py @@ -13,9 +13,9 @@ import time import torch import numpy as np import argparse -from ruamel.yaml import YAML from scipy.spatial.transform import Rotation as R +from config.config import cfg from control_msg import PositionCommand from policy.yopo_network import YopoNetwork from policy.poly_solver import * @@ -32,8 +32,7 @@ class YopoNet: self.config = config rospy.init_node('yopo_net', anonymous=False) # load params - base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r')) + cfg["train"] = False self.height = cfg['image_height'] self.width = cfg['image_width'] self.min_dis, self.max_dis = 0.04, 20.0 @@ -64,7 +63,7 @@ class YopoNet: self.lock = Lock() self.last_control_msg = None self.state_transform = StateTransform() - self.lattice_primitive = LatticePrimitive.get_instance(cfg) + self.lattice_primitive = LatticePrimitive.get_instance() self.traj_time = self.lattice_primitive.segment_time # eval @@ -354,7 +353,7 @@ def parser(): return parser -def main(): +if __name__ == "__main__": args = parser().parse_args() base_dir = os.path.dirname(os.path.abspath(__file__)) weight = "yopo_trt.pth" if args.use_tensorrt else base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch) @@ -372,7 +371,3 @@ def main(): 'visualize': True # 可视化所有轨迹?(实飞改为False节省计算) } YopoNet(settings, weight) - - -if __name__ == "__main__": - main() diff --git a/YOPO/train_yopo.py b/YOPO/train_yopo.py index 493b5ac..9522985 100644 --- a/YOPO/train_yopo.py +++ b/YOPO/train_yopo.py @@ -24,10 +24,9 @@ def parser(): return parser -def main(): +if __name__ == "__main__": args = parser().parse_args() - # set random seed - configure_random_seed(0) + configure_random_seed(0) # set random seed # save the configuration and other files log_dir = os.path.dirname(os.path.abspath(__file__)) + "/saved" @@ -46,7 +45,3 @@ def main(): trainer.train(epoch=50) print("Run YOPO Finish!") - - -if __name__ == "__main__": - main() diff --git a/YOPO/yopo.rviz b/YOPO/yopo.rviz index 178bf13..2d56b3c 100644 --- a/YOPO/yopo.rviz +++ b/YOPO/yopo.rviz @@ -1,12 +1,13 @@ Panels: - Class: rviz/Displays - Help Height: 78 + Help Height: 138 Name: Displays Property Tree Widget: Expanded: - - /PointCloud21/Autocompute Value Bounds1 - Splitter Ratio: 0.5 - Tree Height: 326 + - /Map1/Autocompute Value Bounds1 + - /Trajectory1 + Splitter Ratio: 0.6625221967697144 + Tree Height: 476 - Class: rviz/Selection Name: Selection - Class: rviz/Tool Properties @@ -24,7 +25,7 @@ Panels: - Class: rviz/Time Name: Time SyncMode: 0 - SyncSource: Image + SyncSource: Depth Preferences: PromptSaveOnExit: true Toolbars: @@ -56,7 +57,7 @@ Visualization Manager: Max Value: 1 Median window: 5 Min Value: 0 - Name: Image + Name: Depth Normalize Range: true Queue Size: 2 Transport Hint: raw @@ -78,7 +79,7 @@ Visualization Manager: Invert Rainbow: false Max Color: 255; 255; 255 Min Color: 0; 0; 0 - Name: PointCloud2 + Name: Map Position Transformer: XYZ Queue Size: 10 Selectable: true @@ -108,7 +109,7 @@ Visualization Manager: Invert Rainbow: false Max Color: 255; 255; 255 Min Color: 0; 0; 0 - Name: PointCloud2 + Name: Best_Traj Position Transformer: XYZ Queue Size: 10 Selectable: true @@ -136,7 +137,7 @@ Visualization Manager: Invert Rainbow: true Max Color: 255; 255; 255 Min Color: 0; 0; 0 - Name: PointCloud2 + Name: All_traj Position Transformer: XYZ Queue Size: 10 Selectable: true @@ -164,7 +165,7 @@ Visualization Manager: Invert Rainbow: false Max Color: 255; 255; 255 Min Color: 0; 0; 0 - Name: PointCloud2 + Name: Primitive_Traj Position Transformer: XYZ Queue Size: 10 Selectable: true @@ -177,11 +178,11 @@ Visualization Manager: Use rainbow: true Value: false Enabled: true - Name: Traj + Name: Trajectory - Class: rviz/Marker Enabled: true Marker Topic: /quadrotor_simulator_so3/uav - Name: Marker + Name: Drone Namespaces: mesh: true Queue Size: 100 @@ -235,14 +236,14 @@ Visualization Manager: Yaw: 3.140000104904175 Saved: ~ Window Geometry: + Depth: + collapsed: false Displays: collapsed: false - Height: 1016 + Height: 1600 Hide Left Dock: false Hide Right Dock: true - Image: - collapsed: false - QMainWindow State: 000000ff00000000fd0000000400000000000002350000035afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000001d1000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a0049006d0061006700650100000214000001830000001600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000007380000003efc0100000002fb0000000800540069006d0065010000000000000738000003bc00fffffffb0000000800540069006d00650100000000000004500000000000000000000004fd0000035a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 + QMainWindow State: 000000ff00000000fd0000000400000000000003310000053afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b000000b000fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000b0fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000006e000002d40000018200fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a00440065007000740068010000034e0000025a0000002600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a0000013200fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e1000001970000000300000b700000005efc0100000002fb0000000800540069006d0065010000000000000b70000006dc00fffffffb0000000800540069006d00650100000000000004500000000000000000000008330000053a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 Selection: collapsed: false Time: @@ -251,6 +252,6 @@ Window Geometry: collapsed: false Views: collapsed: true - Width: 1848 - X: 72 - Y: 387 + Width: 2928 + X: 144 + Y: 54 diff --git a/YOPO/yopo_trt_transfer.py b/YOPO/yopo_trt_transfer.py index 1e5c17a..512909c 100644 --- a/YOPO/yopo_trt_transfer.py +++ b/YOPO/yopo_trt_transfer.py @@ -7,13 +7,13 @@ python setup.py install """ -import argparse import os +import argparse +import time import numpy as np import torch from torch2trt import torch2trt -from ruamel.yaml import YAML -import time +from config.config import cfg from policy.yopo_network import YopoNetwork @@ -25,16 +25,15 @@ def parser(): return parser -def main(): +if __name__ == "__main__": args = parser().parse_args() base_dir = os.path.dirname(os.path.abspath(__file__)) - cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r')) weight = base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch) print("Loading Network...") device = "cuda" if torch.cuda.is_available() else "cpu" state_dict = torch.load(weight, weights_only=True) - policy = YopoNetwork(horizon_num=cfg["horizon_num"], vertical_num=cfg["vertical_num"]) + policy = YopoNetwork() policy.load_state_dict(state_dict) policy = policy.to(device) policy.eval() @@ -77,6 +76,3 @@ def main(): f"Transfer Trajectory Error: {traj_error.item():.6f}," f"Transfer Score Error: {score_error.item():.6f}") - -if __name__ == "__main__": - main()