Modify to global config and simplified speed adjustment during training and testing
This commit is contained in:
parent
788e9bc979
commit
bd94cfbd51
0
YOPO/config/__init__.py
Normal file
0
YOPO/config/__init__.py
Normal file
22
YOPO/config/config.py
Normal file
22
YOPO/config/config.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import os
|
||||||
|
from ruamel.yaml import YAML
|
||||||
|
|
||||||
|
|
||||||
|
# Global Configuration Management
|
||||||
|
class Config:
|
||||||
|
def __init__(self):
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
self._data = YAML().load(open(os.path.join(base_dir, "traj_opt.yaml"), 'r'))
|
||||||
|
self._data["train"] = True
|
||||||
|
self._data["goal_length"] = 2.0 * self._data['radio_range']
|
||||||
|
self._data["sgm_time"] = 2 * self._data["radio_range"] / self._data["vel_max_train"]
|
||||||
|
self._data["traj_num"] = self._data['horizon_num'] * self._data['vertical_num'] * self._data["radio_num"]
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._data[key]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self._data[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
cfg = Config()
|
||||||
@ -1,11 +1,12 @@
|
|||||||
# IMPORTANT PARAM: actual velocity in training / testing
|
# IMPORTANT: velocity in testing (modifiable)
|
||||||
velocity: 6.0
|
velocity: 6.0
|
||||||
# used to align the vel/acc, ensure consistency between testing and training.
|
|
||||||
# during testing, if vel*n then acc*n*n, please refer to ../policy/primitive.py
|
|
||||||
vel_align: 6.0
|
|
||||||
acc_align: 6.0
|
|
||||||
|
|
||||||
# IMPORTANT PARAM: weight of penalties (for unit speed)
|
# IMPORTANT: vel_max and acc_max in training
|
||||||
|
# not the actual values in testing, ensure consistency between testing and training, please refer to ../policy/primitive.py
|
||||||
|
vel_max_train: 6.0
|
||||||
|
acc_max_train: 6.0
|
||||||
|
|
||||||
|
# IMPORTANT: weight of costs for unit speed (can be visualized in tensorboard)
|
||||||
wg: 0.1 # guidance
|
wg: 0.1 # guidance
|
||||||
ws: 10.0 # smoothness
|
ws: 10.0 # smoothness
|
||||||
wc: 0.1 # collision
|
wc: 0.1 # collision
|
||||||
@ -15,7 +16,7 @@ dataset_path: "../dataset"
|
|||||||
image_height: 96
|
image_height: 96
|
||||||
image_width: 160
|
image_width: 160
|
||||||
|
|
||||||
# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max)
|
# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max) (can be visualized in rviz)
|
||||||
horizon_num: 5
|
horizon_num: 5
|
||||||
vertical_num: 3
|
vertical_num: 3
|
||||||
horizon_camera_fov: 90.0
|
horizon_camera_fov: 90.0
|
||||||
@ -29,7 +30,7 @@ radio_num: 1 # only support 1 currently
|
|||||||
d0: 1.2
|
d0: 1.2
|
||||||
r: 0.6
|
r: 0.6
|
||||||
|
|
||||||
# distribution parameters for unit state sampling
|
# distribution for unit state sampling (can be visualized by 'python ./policy/yopo_dataset.py')
|
||||||
vx_mean_unit: 0.4
|
vx_mean_unit: 0.4
|
||||||
vy_mean_unit: 0.0
|
vy_mean_unit: 0.0
|
||||||
vz_mean_unit: 0.0
|
vz_mean_unit: 0.0
|
||||||
|
|||||||
@ -1,16 +1,13 @@
|
|||||||
import os
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch as th
|
import torch as th
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from ruamel.yaml import YAML
|
from config.config import cfg
|
||||||
|
|
||||||
|
|
||||||
class GuidanceLoss(nn.Module):
|
class GuidanceLoss(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GuidanceLoss, self).__init__()
|
super(GuidanceLoss, self).__init__()
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
self.goal_length = cfg['goal_length']
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
|
||||||
self.goal_length = 2.0 * cfg['radio_range']
|
|
||||||
|
|
||||||
def forward(self, Df, Dp, goal):
|
def forward(self, Df, Dp, goal):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
import os
|
|
||||||
import math
|
import math
|
||||||
import torch as th
|
import torch as th
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ruamel.yaml import YAML
|
from config.config import cfg
|
||||||
from loss.safety_loss import SafetyLoss
|
from loss.safety_loss import SafetyLoss
|
||||||
from loss.smoothness_loss import SmoothnessLoss
|
from loss.smoothness_loss import SmoothnessLoss
|
||||||
from loss.guidance_loss import GuidanceLoss
|
from loss.guidance_loss import GuidanceLoss
|
||||||
@ -17,20 +16,18 @@ class YOPOLoss(nn.Module):
|
|||||||
df: fixed parameters
|
df: fixed parameters
|
||||||
"""
|
"""
|
||||||
super(YOPOLoss, self).__init__()
|
super(YOPOLoss, self).__init__()
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
self.sgm_time = cfg["sgm_time"]
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
|
||||||
self.sgm_time = 2 * cfg["radio_range"] / cfg["velocity"]
|
|
||||||
self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
|
self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
|
||||||
self._C, self._B, self._L, self._R = self.qp_generation()
|
self._C, self._B, self._L, self._R = self.qp_generation()
|
||||||
self._R = self._R.to(self.device)
|
self._R = self._R.to(self.device)
|
||||||
self._L = self._L.to(self.device)
|
self._L = self._L.to(self.device)
|
||||||
vel_scale = cfg["velocity"] / 1.0
|
vel_scale = cfg["vel_max_train"] / 1.0
|
||||||
self.smoothness_weight = cfg["ws"]
|
self.smoothness_weight = cfg["ws"]
|
||||||
self.safety_weight = cfg["wc"]
|
self.safety_weight = cfg["wc"]
|
||||||
self.goal_weight = cfg["wg"]
|
self.goal_weight = cfg["wg"]
|
||||||
self.denormalize_weight(vel_scale)
|
self.denormalize_weight(vel_scale)
|
||||||
self.smoothness_loss = SmoothnessLoss(self._R)
|
self.smoothness_loss = SmoothnessLoss(self._R)
|
||||||
self.safety_loss = SafetyLoss(self._L, self.sgm_time)
|
self.safety_loss = SafetyLoss(self._L)
|
||||||
self.goal_loss = GuidanceLoss()
|
self.goal_loss = GuidanceLoss()
|
||||||
print("------ Actual Loss ------")
|
print("------ Actual Loss ------")
|
||||||
print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |")
|
print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |")
|
||||||
|
|||||||
@ -5,23 +5,21 @@ import torch as th
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import open3d as o3d
|
import open3d as o3d
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from scipy.ndimage import distance_transform_edt
|
from scipy.ndimage import distance_transform_edt
|
||||||
|
from config.config import cfg
|
||||||
|
|
||||||
|
|
||||||
class SafetyLoss(nn.Module):
|
class SafetyLoss(nn.Module):
|
||||||
def __init__(self, L, sgm_time):
|
def __init__(self, L):
|
||||||
super(SafetyLoss, self).__init__()
|
super(SafetyLoss, self).__init__()
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
self.traj_num = cfg['traj_num']
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
|
||||||
self.traj_num = cfg['horizon_num'] * cfg['vertical_num']
|
|
||||||
self.map_expand_min = np.array(cfg['map_expand_min'])
|
self.map_expand_min = np.array(cfg['map_expand_min'])
|
||||||
self.map_expand_max = np.array(cfg['map_expand_max'])
|
self.map_expand_max = np.array(cfg['map_expand_max'])
|
||||||
self.d0 = cfg["d0"]
|
self.d0 = cfg["d0"]
|
||||||
self.r = cfg["r"]
|
self.r = cfg["r"]
|
||||||
|
|
||||||
self._L = L
|
self._L = L
|
||||||
self.sgm_time = sgm_time
|
self.sgm_time = cfg["sgm_time"]
|
||||||
self.eval_points = 30
|
self.eval_points = 30
|
||||||
self.device = self._L.device
|
self.device = self._L.device
|
||||||
|
|
||||||
@ -31,6 +29,7 @@ class SafetyLoss(nn.Module):
|
|||||||
self.max_bounds = None # shape: (N, 3)
|
self.max_bounds = None # shape: (N, 3)
|
||||||
self.sdf_shapes = None # shape: (1, 3)
|
self.sdf_shapes = None # shape: (1, 3)
|
||||||
print("Building ESDF map...")
|
print("Building ESDF map...")
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
||||||
self.sdf_maps = self.get_sdf_from_ply(data_dir)
|
self.sdf_maps = self.get_sdf_from_ply(data_dir)
|
||||||
print("Map built!")
|
print("Map built!")
|
||||||
|
|||||||
@ -1,16 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
|
from config.config import cfg
|
||||||
|
|
||||||
|
|
||||||
class LatticeParam:
|
class LatticeParam:
|
||||||
def __init__(self, cfg):
|
def __init__(self):
|
||||||
ratio = cfg["velocity"] / cfg["vel_align"]
|
ratio = 1.0 if cfg["train"] else cfg["velocity"] / cfg["vel_max_train"]
|
||||||
self.vel_max = ratio * cfg["vel_align"]
|
self.vel_max = ratio * cfg["vel_max_train"]
|
||||||
self.acc_max = ratio * ratio * cfg["acc_align"]
|
self.acc_max = ratio * ratio * cfg["acc_max_train"]
|
||||||
self.segment_time = 2 * cfg["radio_range"] / self.vel_max
|
self.segment_time = cfg["sgm_time"] / ratio
|
||||||
self.horizon_num = cfg["horizon_num"]
|
self.horizon_num = cfg["horizon_num"]
|
||||||
self.vertical_num = cfg["vertical_num"]
|
self.vertical_num = cfg["vertical_num"]
|
||||||
self.radio_num = cfg["radio_num"]
|
self.radio_num = cfg["radio_num"]
|
||||||
|
self.traj_num = cfg["traj_num"]
|
||||||
self.horizon_fov = cfg["horizon_camera_fov"]
|
self.horizon_fov = cfg["horizon_camera_fov"]
|
||||||
self.vertical_fov = cfg["vertical_camera_fov"]
|
self.vertical_fov = cfg["vertical_camera_fov"]
|
||||||
self.horizon_anchor_fov = cfg["horizon_anchor_fov"]
|
self.horizon_anchor_fov = cfg["horizon_anchor_fov"]
|
||||||
@ -38,12 +40,10 @@ class LatticePrimitive(LatticeParam):
|
|||||||
"""
|
"""
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self):
|
||||||
super().__init__(cfg)
|
super().__init__()
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
self.traj_num = self.vertical_num * self.horizon_num * self.radio_num
|
|
||||||
|
|
||||||
if self.horizon_num == 1:
|
if self.horizon_num == 1:
|
||||||
direction_diff = 0
|
direction_diff = 0
|
||||||
else:
|
else:
|
||||||
@ -105,16 +105,6 @@ class LatticePrimitive(LatticeParam):
|
|||||||
return self.traj_num - id - 1
|
return self.traj_num - id - 1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(self, cfg):
|
def get_instance(self):
|
||||||
if self._instance is None: self._instance = self(cfg)
|
if self._instance is None: self._instance = self()
|
||||||
return self._instance
|
return self._instance
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import os
|
|
||||||
from ruamel.yaml import YAML
|
|
||||||
|
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
|
||||||
lattice_primitive = LatticePrimitive.get_instance(cfg)
|
|
||||||
print(lattice_primitive.getStateLattice(list(range(lattice_primitive.traj_num))))
|
|
||||||
|
|||||||
@ -1,16 +1,13 @@
|
|||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ruamel.yaml import YAML
|
from config.config import cfg
|
||||||
from policy.primitive import LatticePrimitive
|
from policy.primitive import LatticePrimitive
|
||||||
|
|
||||||
|
|
||||||
class StateTransform:
|
class StateTransform:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
self.lattice_primitive = LatticePrimitive.get_instance()
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
self.goal_length = cfg['goal_length']
|
||||||
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
|
|
||||||
self.goal_length = 2.0 * cfg['radio_range']
|
|
||||||
|
|
||||||
def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor:
|
def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,38 +1,37 @@
|
|||||||
import os
|
import os, sys
|
||||||
import cv2
|
import cv2
|
||||||
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from ruamel.yaml import YAML
|
|
||||||
import time
|
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
from config.config import cfg
|
||||||
|
|
||||||
|
|
||||||
class YOPODataset(Dataset):
|
class YOPODataset(Dataset):
|
||||||
def __init__(self, mode='train', val_ratio=0.1):
|
def __init__(self, mode='train', val_ratio=0.1):
|
||||||
super(YOPODataset, self).__init__()
|
super(YOPODataset, self).__init__()
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
|
||||||
# image params
|
# image params
|
||||||
self.height = int(cfg["image_height"])
|
self.height = int(cfg["image_height"])
|
||||||
self.width = int(cfg["image_width"])
|
self.width = int(cfg["image_width"])
|
||||||
# ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution
|
# ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution
|
||||||
scale = cfg["velocity"] / cfg["vel_align"]
|
self.vel_max = cfg["vel_max_train"]
|
||||||
self.vel_max = scale * cfg["vel_align"]
|
self.acc_max = cfg["acc_max_train"]
|
||||||
self.acc_max = scale * scale * cfg["acc_align"]
|
|
||||||
self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"])
|
self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"])
|
||||||
self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"])
|
self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"])
|
||||||
self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
|
self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
|
||||||
self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]])
|
self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]])
|
||||||
self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
|
self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
|
||||||
self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]])
|
self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]])
|
||||||
self.goal_length = 2.0 * cfg['radio_range']
|
self.goal_length = cfg['goal_length']
|
||||||
self.goal_pitch_std = cfg["goal_pitch_std"]
|
self.goal_pitch_std = cfg["goal_pitch_std"]
|
||||||
self.goal_yaw_std = cfg["goal_yaw_std"]
|
self.goal_yaw_std = cfg["goal_yaw_std"]
|
||||||
if mode == 'train': self.print_data()
|
if mode == 'train': self.print_data()
|
||||||
|
|
||||||
# dataset
|
# dataset
|
||||||
print("Loading", mode, "dataset, it may take a while...")
|
print("Loading", mode, "dataset, it may take a while...")
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
||||||
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)
|
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
Training Strategy
|
Training Strategy
|
||||||
supervised learning, imitation learning, testing, rollout
|
supervised learning, imitation learning, testing, rollout
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import atexit
|
import atexit
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
@ -9,6 +10,7 @@ from rich.progress import Progress
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.tensorboard.writer import SummaryWriter
|
from torch.utils.tensorboard.writer import SummaryWriter
|
||||||
|
|
||||||
|
from config.config import cfg
|
||||||
from loss.loss_function import YOPOLoss
|
from loss.loss_function import YOPOLoss
|
||||||
from policy.yopo_network import YopoNetwork
|
from policy.yopo_network import YopoNetwork
|
||||||
from policy.yopo_dataset import YOPODataset
|
from policy.yopo_dataset import YOPODataset
|
||||||
@ -35,10 +37,7 @@ class YopoTrainer:
|
|||||||
self.tensorboard_path = self.get_next_log_path(tensorboard_path)
|
self.tensorboard_path = self.get_next_log_path(tensorboard_path)
|
||||||
self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path)
|
self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path)
|
||||||
# params
|
# params
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
self.traj_num = cfg['traj_num']
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
|
||||||
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
|
|
||||||
self.traj_num = self.lattice_primitive.traj_num
|
|
||||||
|
|
||||||
# loss
|
# loss
|
||||||
self.yopo_loss = YOPOLoss()
|
self.yopo_loss = YOPOLoss()
|
||||||
|
|||||||
@ -13,9 +13,9 @@ import time
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import argparse
|
import argparse
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
|
|
||||||
|
from config.config import cfg
|
||||||
from control_msg import PositionCommand
|
from control_msg import PositionCommand
|
||||||
from policy.yopo_network import YopoNetwork
|
from policy.yopo_network import YopoNetwork
|
||||||
from policy.poly_solver import *
|
from policy.poly_solver import *
|
||||||
@ -32,8 +32,7 @@ class YopoNet:
|
|||||||
self.config = config
|
self.config = config
|
||||||
rospy.init_node('yopo_net', anonymous=False)
|
rospy.init_node('yopo_net', anonymous=False)
|
||||||
# load params
|
# load params
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
cfg["train"] = False
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
|
|
||||||
self.height = cfg['image_height']
|
self.height = cfg['image_height']
|
||||||
self.width = cfg['image_width']
|
self.width = cfg['image_width']
|
||||||
self.min_dis, self.max_dis = 0.04, 20.0
|
self.min_dis, self.max_dis = 0.04, 20.0
|
||||||
@ -64,7 +63,7 @@ class YopoNet:
|
|||||||
self.lock = Lock()
|
self.lock = Lock()
|
||||||
self.last_control_msg = None
|
self.last_control_msg = None
|
||||||
self.state_transform = StateTransform()
|
self.state_transform = StateTransform()
|
||||||
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
|
self.lattice_primitive = LatticePrimitive.get_instance()
|
||||||
self.traj_time = self.lattice_primitive.segment_time
|
self.traj_time = self.lattice_primitive.segment_time
|
||||||
|
|
||||||
# eval
|
# eval
|
||||||
@ -354,7 +353,7 @@ def parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
args = parser().parse_args()
|
args = parser().parse_args()
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
weight = "yopo_trt.pth" if args.use_tensorrt else base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
|
weight = "yopo_trt.pth" if args.use_tensorrt else base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
|
||||||
@ -372,7 +371,3 @@ def main():
|
|||||||
'visualize': True # 可视化所有轨迹?(实飞改为False节省计算)
|
'visualize': True # 可视化所有轨迹?(实飞改为False节省计算)
|
||||||
}
|
}
|
||||||
YopoNet(settings, weight)
|
YopoNet(settings, weight)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
@ -24,10 +24,9 @@ def parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
args = parser().parse_args()
|
args = parser().parse_args()
|
||||||
# set random seed
|
configure_random_seed(0) # set random seed
|
||||||
configure_random_seed(0)
|
|
||||||
|
|
||||||
# save the configuration and other files
|
# save the configuration and other files
|
||||||
log_dir = os.path.dirname(os.path.abspath(__file__)) + "/saved"
|
log_dir = os.path.dirname(os.path.abspath(__file__)) + "/saved"
|
||||||
@ -46,7 +45,3 @@ def main():
|
|||||||
trainer.train(epoch=50)
|
trainer.train(epoch=50)
|
||||||
|
|
||||||
print("Run YOPO Finish!")
|
print("Run YOPO Finish!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
Panels:
|
Panels:
|
||||||
- Class: rviz/Displays
|
- Class: rviz/Displays
|
||||||
Help Height: 78
|
Help Height: 138
|
||||||
Name: Displays
|
Name: Displays
|
||||||
Property Tree Widget:
|
Property Tree Widget:
|
||||||
Expanded:
|
Expanded:
|
||||||
- /PointCloud21/Autocompute Value Bounds1
|
- /Map1/Autocompute Value Bounds1
|
||||||
Splitter Ratio: 0.5
|
- /Trajectory1
|
||||||
Tree Height: 326
|
Splitter Ratio: 0.6625221967697144
|
||||||
|
Tree Height: 476
|
||||||
- Class: rviz/Selection
|
- Class: rviz/Selection
|
||||||
Name: Selection
|
Name: Selection
|
||||||
- Class: rviz/Tool Properties
|
- Class: rviz/Tool Properties
|
||||||
@ -24,7 +25,7 @@ Panels:
|
|||||||
- Class: rviz/Time
|
- Class: rviz/Time
|
||||||
Name: Time
|
Name: Time
|
||||||
SyncMode: 0
|
SyncMode: 0
|
||||||
SyncSource: Image
|
SyncSource: Depth
|
||||||
Preferences:
|
Preferences:
|
||||||
PromptSaveOnExit: true
|
PromptSaveOnExit: true
|
||||||
Toolbars:
|
Toolbars:
|
||||||
@ -56,7 +57,7 @@ Visualization Manager:
|
|||||||
Max Value: 1
|
Max Value: 1
|
||||||
Median window: 5
|
Median window: 5
|
||||||
Min Value: 0
|
Min Value: 0
|
||||||
Name: Image
|
Name: Depth
|
||||||
Normalize Range: true
|
Normalize Range: true
|
||||||
Queue Size: 2
|
Queue Size: 2
|
||||||
Transport Hint: raw
|
Transport Hint: raw
|
||||||
@ -78,7 +79,7 @@ Visualization Manager:
|
|||||||
Invert Rainbow: false
|
Invert Rainbow: false
|
||||||
Max Color: 255; 255; 255
|
Max Color: 255; 255; 255
|
||||||
Min Color: 0; 0; 0
|
Min Color: 0; 0; 0
|
||||||
Name: PointCloud2
|
Name: Map
|
||||||
Position Transformer: XYZ
|
Position Transformer: XYZ
|
||||||
Queue Size: 10
|
Queue Size: 10
|
||||||
Selectable: true
|
Selectable: true
|
||||||
@ -108,7 +109,7 @@ Visualization Manager:
|
|||||||
Invert Rainbow: false
|
Invert Rainbow: false
|
||||||
Max Color: 255; 255; 255
|
Max Color: 255; 255; 255
|
||||||
Min Color: 0; 0; 0
|
Min Color: 0; 0; 0
|
||||||
Name: PointCloud2
|
Name: Best_Traj
|
||||||
Position Transformer: XYZ
|
Position Transformer: XYZ
|
||||||
Queue Size: 10
|
Queue Size: 10
|
||||||
Selectable: true
|
Selectable: true
|
||||||
@ -136,7 +137,7 @@ Visualization Manager:
|
|||||||
Invert Rainbow: true
|
Invert Rainbow: true
|
||||||
Max Color: 255; 255; 255
|
Max Color: 255; 255; 255
|
||||||
Min Color: 0; 0; 0
|
Min Color: 0; 0; 0
|
||||||
Name: PointCloud2
|
Name: All_traj
|
||||||
Position Transformer: XYZ
|
Position Transformer: XYZ
|
||||||
Queue Size: 10
|
Queue Size: 10
|
||||||
Selectable: true
|
Selectable: true
|
||||||
@ -164,7 +165,7 @@ Visualization Manager:
|
|||||||
Invert Rainbow: false
|
Invert Rainbow: false
|
||||||
Max Color: 255; 255; 255
|
Max Color: 255; 255; 255
|
||||||
Min Color: 0; 0; 0
|
Min Color: 0; 0; 0
|
||||||
Name: PointCloud2
|
Name: Primitive_Traj
|
||||||
Position Transformer: XYZ
|
Position Transformer: XYZ
|
||||||
Queue Size: 10
|
Queue Size: 10
|
||||||
Selectable: true
|
Selectable: true
|
||||||
@ -177,11 +178,11 @@ Visualization Manager:
|
|||||||
Use rainbow: true
|
Use rainbow: true
|
||||||
Value: false
|
Value: false
|
||||||
Enabled: true
|
Enabled: true
|
||||||
Name: Traj
|
Name: Trajectory
|
||||||
- Class: rviz/Marker
|
- Class: rviz/Marker
|
||||||
Enabled: true
|
Enabled: true
|
||||||
Marker Topic: /quadrotor_simulator_so3/uav
|
Marker Topic: /quadrotor_simulator_so3/uav
|
||||||
Name: Marker
|
Name: Drone
|
||||||
Namespaces:
|
Namespaces:
|
||||||
mesh: true
|
mesh: true
|
||||||
Queue Size: 100
|
Queue Size: 100
|
||||||
@ -235,14 +236,14 @@ Visualization Manager:
|
|||||||
Yaw: 3.140000104904175
|
Yaw: 3.140000104904175
|
||||||
Saved: ~
|
Saved: ~
|
||||||
Window Geometry:
|
Window Geometry:
|
||||||
|
Depth:
|
||||||
|
collapsed: false
|
||||||
Displays:
|
Displays:
|
||||||
collapsed: false
|
collapsed: false
|
||||||
Height: 1016
|
Height: 1600
|
||||||
Hide Left Dock: false
|
Hide Left Dock: false
|
||||||
Hide Right Dock: true
|
Hide Right Dock: true
|
||||||
Image:
|
QMainWindow State: 000000ff00000000fd0000000400000000000003310000053afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b000000b000fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000b0fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000006e000002d40000018200fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a00440065007000740068010000034e0000025a0000002600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a0000013200fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e1000001970000000300000b700000005efc0100000002fb0000000800540069006d0065010000000000000b70000006dc00fffffffb0000000800540069006d00650100000000000004500000000000000000000008330000053a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
|
||||||
collapsed: false
|
|
||||||
QMainWindow State: 000000ff00000000fd0000000400000000000002350000035afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000001d1000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a0049006d0061006700650100000214000001830000001600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000007380000003efc0100000002fb0000000800540069006d0065010000000000000738000003bc00fffffffb0000000800540069006d00650100000000000004500000000000000000000004fd0000035a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
|
|
||||||
Selection:
|
Selection:
|
||||||
collapsed: false
|
collapsed: false
|
||||||
Time:
|
Time:
|
||||||
@ -251,6 +252,6 @@ Window Geometry:
|
|||||||
collapsed: false
|
collapsed: false
|
||||||
Views:
|
Views:
|
||||||
collapsed: true
|
collapsed: true
|
||||||
Width: 1848
|
Width: 2928
|
||||||
X: 72
|
X: 144
|
||||||
Y: 387
|
Y: 54
|
||||||
|
|||||||
@ -7,13 +7,13 @@
|
|||||||
python setup.py install
|
python setup.py install
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch2trt import torch2trt
|
from torch2trt import torch2trt
|
||||||
from ruamel.yaml import YAML
|
from config.config import cfg
|
||||||
import time
|
|
||||||
from policy.yopo_network import YopoNetwork
|
from policy.yopo_network import YopoNetwork
|
||||||
|
|
||||||
|
|
||||||
@ -25,16 +25,15 @@ def parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
args = parser().parse_args()
|
args = parser().parse_args()
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
|
|
||||||
weight = base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
|
weight = base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
|
||||||
|
|
||||||
print("Loading Network...")
|
print("Loading Network...")
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
state_dict = torch.load(weight, weights_only=True)
|
state_dict = torch.load(weight, weights_only=True)
|
||||||
policy = YopoNetwork(horizon_num=cfg["horizon_num"], vertical_num=cfg["vertical_num"])
|
policy = YopoNetwork()
|
||||||
policy.load_state_dict(state_dict)
|
policy.load_state_dict(state_dict)
|
||||||
policy = policy.to(device)
|
policy = policy.to(device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
@ -77,6 +76,3 @@ def main():
|
|||||||
f"Transfer Trajectory Error: {traj_error.item():.6f},"
|
f"Transfer Trajectory Error: {traj_error.item():.6f},"
|
||||||
f"Transfer Score Error: {score_error.item():.6f}")
|
f"Transfer Score Error: {score_error.item():.6f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user