Modify to global config and simplified speed adjustment during training and testing
This commit is contained in:
parent
788e9bc979
commit
bd94cfbd51
0
YOPO/config/__init__.py
Normal file
0
YOPO/config/__init__.py
Normal file
22
YOPO/config/config.py
Normal file
22
YOPO/config/config.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
# Global Configuration Management
|
||||
class Config:
|
||||
def __init__(self):
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self._data = YAML().load(open(os.path.join(base_dir, "traj_opt.yaml"), 'r'))
|
||||
self._data["train"] = True
|
||||
self._data["goal_length"] = 2.0 * self._data['radio_range']
|
||||
self._data["sgm_time"] = 2 * self._data["radio_range"] / self._data["vel_max_train"]
|
||||
self._data["traj_num"] = self._data['horizon_num'] * self._data['vertical_num'] * self._data["radio_num"]
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._data[key] = value
|
||||
|
||||
|
||||
cfg = Config()
|
||||
@ -1,11 +1,12 @@
|
||||
# IMPORTANT PARAM: actual velocity in training / testing
|
||||
# IMPORTANT: velocity in testing (modifiable)
|
||||
velocity: 6.0
|
||||
# used to align the vel/acc, ensure consistency between testing and training.
|
||||
# during testing, if vel*n then acc*n*n, please refer to ../policy/primitive.py
|
||||
vel_align: 6.0
|
||||
acc_align: 6.0
|
||||
|
||||
# IMPORTANT PARAM: weight of penalties (for unit speed)
|
||||
# IMPORTANT: vel_max and acc_max in training
|
||||
# not the actual values in testing, ensure consistency between testing and training, please refer to ../policy/primitive.py
|
||||
vel_max_train: 6.0
|
||||
acc_max_train: 6.0
|
||||
|
||||
# IMPORTANT: weight of costs for unit speed (can be visualized in tensorboard)
|
||||
wg: 0.1 # guidance
|
||||
ws: 10.0 # smoothness
|
||||
wc: 0.1 # collision
|
||||
@ -15,7 +16,7 @@ dataset_path: "../dataset"
|
||||
image_height: 96
|
||||
image_width: 160
|
||||
|
||||
# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max)
|
||||
# trajectory and primitive param (Note: traj_time = 2 * radio / vel_max) (can be visualized in rviz)
|
||||
horizon_num: 5
|
||||
vertical_num: 3
|
||||
horizon_camera_fov: 90.0
|
||||
@ -29,7 +30,7 @@ radio_num: 1 # only support 1 currently
|
||||
d0: 1.2
|
||||
r: 0.6
|
||||
|
||||
# distribution parameters for unit state sampling
|
||||
# distribution for unit state sampling (can be visualized by 'python ./policy/yopo_dataset.py')
|
||||
vx_mean_unit: 0.4
|
||||
vy_mean_unit: 0.0
|
||||
vz_mean_unit: 0.0
|
||||
|
||||
@ -1,16 +1,13 @@
|
||||
import os
|
||||
import torch.nn as nn
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
from ruamel.yaml import YAML
|
||||
from config.config import cfg
|
||||
|
||||
|
||||
class GuidanceLoss(nn.Module):
|
||||
def __init__(self):
|
||||
super(GuidanceLoss, self).__init__()
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
||||
self.goal_length = 2.0 * cfg['radio_range']
|
||||
self.goal_length = cfg['goal_length']
|
||||
|
||||
def forward(self, Df, Dp, goal):
|
||||
"""
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import os
|
||||
import math
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from ruamel.yaml import YAML
|
||||
from config.config import cfg
|
||||
from loss.safety_loss import SafetyLoss
|
||||
from loss.smoothness_loss import SmoothnessLoss
|
||||
from loss.guidance_loss import GuidanceLoss
|
||||
@ -17,20 +16,18 @@ class YOPOLoss(nn.Module):
|
||||
df: fixed parameters
|
||||
"""
|
||||
super(YOPOLoss, self).__init__()
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
||||
self.sgm_time = 2 * cfg["radio_range"] / cfg["velocity"]
|
||||
self.sgm_time = cfg["sgm_time"]
|
||||
self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
|
||||
self._C, self._B, self._L, self._R = self.qp_generation()
|
||||
self._R = self._R.to(self.device)
|
||||
self._L = self._L.to(self.device)
|
||||
vel_scale = cfg["velocity"] / 1.0
|
||||
vel_scale = cfg["vel_max_train"] / 1.0
|
||||
self.smoothness_weight = cfg["ws"]
|
||||
self.safety_weight = cfg["wc"]
|
||||
self.goal_weight = cfg["wg"]
|
||||
self.denormalize_weight(vel_scale)
|
||||
self.smoothness_loss = SmoothnessLoss(self._R)
|
||||
self.safety_loss = SafetyLoss(self._L, self.sgm_time)
|
||||
self.safety_loss = SafetyLoss(self._L)
|
||||
self.goal_loss = GuidanceLoss()
|
||||
print("------ Actual Loss ------")
|
||||
print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |")
|
||||
|
||||
@ -5,23 +5,21 @@ import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import open3d as o3d
|
||||
from ruamel.yaml import YAML
|
||||
from scipy.ndimage import distance_transform_edt
|
||||
from config.config import cfg
|
||||
|
||||
|
||||
class SafetyLoss(nn.Module):
|
||||
def __init__(self, L, sgm_time):
|
||||
def __init__(self, L):
|
||||
super(SafetyLoss, self).__init__()
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
||||
self.traj_num = cfg['horizon_num'] * cfg['vertical_num']
|
||||
self.traj_num = cfg['traj_num']
|
||||
self.map_expand_min = np.array(cfg['map_expand_min'])
|
||||
self.map_expand_max = np.array(cfg['map_expand_max'])
|
||||
self.d0 = cfg["d0"]
|
||||
self.r = cfg["r"]
|
||||
|
||||
self._L = L
|
||||
self.sgm_time = sgm_time
|
||||
self.sgm_time = cfg["sgm_time"]
|
||||
self.eval_points = 30
|
||||
self.device = self._L.device
|
||||
|
||||
@ -31,6 +29,7 @@ class SafetyLoss(nn.Module):
|
||||
self.max_bounds = None # shape: (N, 3)
|
||||
self.sdf_shapes = None # shape: (1, 3)
|
||||
print("Building ESDF map...")
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
||||
self.sdf_maps = self.get_sdf_from_ply(data_dir)
|
||||
print("Map built!")
|
||||
|
||||
@ -1,16 +1,18 @@
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
from config.config import cfg
|
||||
|
||||
|
||||
class LatticeParam:
|
||||
def __init__(self, cfg):
|
||||
ratio = cfg["velocity"] / cfg["vel_align"]
|
||||
self.vel_max = ratio * cfg["vel_align"]
|
||||
self.acc_max = ratio * ratio * cfg["acc_align"]
|
||||
self.segment_time = 2 * cfg["radio_range"] / self.vel_max
|
||||
def __init__(self):
|
||||
ratio = 1.0 if cfg["train"] else cfg["velocity"] / cfg["vel_max_train"]
|
||||
self.vel_max = ratio * cfg["vel_max_train"]
|
||||
self.acc_max = ratio * ratio * cfg["acc_max_train"]
|
||||
self.segment_time = cfg["sgm_time"] / ratio
|
||||
self.horizon_num = cfg["horizon_num"]
|
||||
self.vertical_num = cfg["vertical_num"]
|
||||
self.radio_num = cfg["radio_num"]
|
||||
self.traj_num = cfg["traj_num"]
|
||||
self.horizon_fov = cfg["horizon_camera_fov"]
|
||||
self.vertical_fov = cfg["vertical_camera_fov"]
|
||||
self.horizon_anchor_fov = cfg["horizon_anchor_fov"]
|
||||
@ -38,12 +40,10 @@ class LatticePrimitive(LatticeParam):
|
||||
"""
|
||||
_instance = None
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.traj_num = self.vertical_num * self.horizon_num * self.radio_num
|
||||
|
||||
if self.horizon_num == 1:
|
||||
direction_diff = 0
|
||||
else:
|
||||
@ -105,16 +105,6 @@ class LatticePrimitive(LatticeParam):
|
||||
return self.traj_num - id - 1
|
||||
|
||||
@classmethod
|
||||
def get_instance(self, cfg):
|
||||
if self._instance is None: self._instance = self(cfg)
|
||||
def get_instance(self):
|
||||
if self._instance is None: self._instance = self()
|
||||
return self._instance
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import os
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
||||
lattice_primitive = LatticePrimitive.get_instance(cfg)
|
||||
print(lattice_primitive.getStateLattice(list(range(lattice_primitive.traj_num))))
|
||||
|
||||
@ -1,16 +1,13 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from ruamel.yaml import YAML
|
||||
from config.config import cfg
|
||||
from policy.primitive import LatticePrimitive
|
||||
|
||||
|
||||
class StateTransform:
|
||||
def __init__(self):
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
||||
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
|
||||
self.goal_length = 2.0 * cfg['radio_range']
|
||||
self.lattice_primitive = LatticePrimitive.get_instance()
|
||||
self.goal_length = cfg['goal_length']
|
||||
|
||||
def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@ -1,38 +1,37 @@
|
||||
import os
|
||||
import os, sys
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from ruamel.yaml import YAML
|
||||
import time
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
from sklearn.model_selection import train_test_split
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from config.config import cfg
|
||||
|
||||
|
||||
class YOPODataset(Dataset):
|
||||
def __init__(self, mode='train', val_ratio=0.1):
|
||||
super(YOPODataset, self).__init__()
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
||||
# image params
|
||||
self.height = int(cfg["image_height"])
|
||||
self.width = int(cfg["image_width"])
|
||||
# ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution
|
||||
scale = cfg["velocity"] / cfg["vel_align"]
|
||||
self.vel_max = scale * cfg["vel_align"]
|
||||
self.acc_max = scale * scale * cfg["acc_align"]
|
||||
self.vel_max = cfg["vel_max_train"]
|
||||
self.acc_max = cfg["acc_max_train"]
|
||||
self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"])
|
||||
self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"])
|
||||
self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
|
||||
self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]])
|
||||
self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
|
||||
self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]])
|
||||
self.goal_length = 2.0 * cfg['radio_range']
|
||||
self.goal_length = cfg['goal_length']
|
||||
self.goal_pitch_std = cfg["goal_pitch_std"]
|
||||
self.goal_yaw_std = cfg["goal_yaw_std"]
|
||||
if mode == 'train': self.print_data()
|
||||
|
||||
# dataset
|
||||
print("Loading", mode, "dataset, it may take a while...")
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
||||
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Training Strategy
|
||||
supervised learning, imitation learning, testing, rollout
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import atexit
|
||||
from torch.nn import functional as F
|
||||
@ -9,6 +10,7 @@ from rich.progress import Progress
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard.writer import SummaryWriter
|
||||
|
||||
from config.config import cfg
|
||||
from loss.loss_function import YOPOLoss
|
||||
from policy.yopo_network import YopoNetwork
|
||||
from policy.yopo_dataset import YOPODataset
|
||||
@ -35,10 +37,7 @@ class YopoTrainer:
|
||||
self.tensorboard_path = self.get_next_log_path(tensorboard_path)
|
||||
self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path)
|
||||
# params
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
|
||||
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
|
||||
self.traj_num = self.lattice_primitive.traj_num
|
||||
self.traj_num = cfg['traj_num']
|
||||
|
||||
# loss
|
||||
self.yopo_loss = YOPOLoss()
|
||||
|
||||
@ -13,9 +13,9 @@ import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
from ruamel.yaml import YAML
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from config.config import cfg
|
||||
from control_msg import PositionCommand
|
||||
from policy.yopo_network import YopoNetwork
|
||||
from policy.poly_solver import *
|
||||
@ -32,8 +32,7 @@ class YopoNet:
|
||||
self.config = config
|
||||
rospy.init_node('yopo_net', anonymous=False)
|
||||
# load params
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
|
||||
cfg["train"] = False
|
||||
self.height = cfg['image_height']
|
||||
self.width = cfg['image_width']
|
||||
self.min_dis, self.max_dis = 0.04, 20.0
|
||||
@ -64,7 +63,7 @@ class YopoNet:
|
||||
self.lock = Lock()
|
||||
self.last_control_msg = None
|
||||
self.state_transform = StateTransform()
|
||||
self.lattice_primitive = LatticePrimitive.get_instance(cfg)
|
||||
self.lattice_primitive = LatticePrimitive.get_instance()
|
||||
self.traj_time = self.lattice_primitive.segment_time
|
||||
|
||||
# eval
|
||||
@ -354,7 +353,7 @@ def parser():
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
if __name__ == "__main__":
|
||||
args = parser().parse_args()
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
weight = "yopo_trt.pth" if args.use_tensorrt else base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
|
||||
@ -372,7 +371,3 @@ def main():
|
||||
'visualize': True # 可视化所有轨迹?(实飞改为False节省计算)
|
||||
}
|
||||
YopoNet(settings, weight)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -24,10 +24,9 @@ def parser():
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
if __name__ == "__main__":
|
||||
args = parser().parse_args()
|
||||
# set random seed
|
||||
configure_random_seed(0)
|
||||
configure_random_seed(0) # set random seed
|
||||
|
||||
# save the configuration and other files
|
||||
log_dir = os.path.dirname(os.path.abspath(__file__)) + "/saved"
|
||||
@ -46,7 +45,3 @@ def main():
|
||||
trainer.train(epoch=50)
|
||||
|
||||
print("Run YOPO Finish!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
Panels:
|
||||
- Class: rviz/Displays
|
||||
Help Height: 78
|
||||
Help Height: 138
|
||||
Name: Displays
|
||||
Property Tree Widget:
|
||||
Expanded:
|
||||
- /PointCloud21/Autocompute Value Bounds1
|
||||
Splitter Ratio: 0.5
|
||||
Tree Height: 326
|
||||
- /Map1/Autocompute Value Bounds1
|
||||
- /Trajectory1
|
||||
Splitter Ratio: 0.6625221967697144
|
||||
Tree Height: 476
|
||||
- Class: rviz/Selection
|
||||
Name: Selection
|
||||
- Class: rviz/Tool Properties
|
||||
@ -24,7 +25,7 @@ Panels:
|
||||
- Class: rviz/Time
|
||||
Name: Time
|
||||
SyncMode: 0
|
||||
SyncSource: Image
|
||||
SyncSource: Depth
|
||||
Preferences:
|
||||
PromptSaveOnExit: true
|
||||
Toolbars:
|
||||
@ -56,7 +57,7 @@ Visualization Manager:
|
||||
Max Value: 1
|
||||
Median window: 5
|
||||
Min Value: 0
|
||||
Name: Image
|
||||
Name: Depth
|
||||
Normalize Range: true
|
||||
Queue Size: 2
|
||||
Transport Hint: raw
|
||||
@ -78,7 +79,7 @@ Visualization Manager:
|
||||
Invert Rainbow: false
|
||||
Max Color: 255; 255; 255
|
||||
Min Color: 0; 0; 0
|
||||
Name: PointCloud2
|
||||
Name: Map
|
||||
Position Transformer: XYZ
|
||||
Queue Size: 10
|
||||
Selectable: true
|
||||
@ -108,7 +109,7 @@ Visualization Manager:
|
||||
Invert Rainbow: false
|
||||
Max Color: 255; 255; 255
|
||||
Min Color: 0; 0; 0
|
||||
Name: PointCloud2
|
||||
Name: Best_Traj
|
||||
Position Transformer: XYZ
|
||||
Queue Size: 10
|
||||
Selectable: true
|
||||
@ -136,7 +137,7 @@ Visualization Manager:
|
||||
Invert Rainbow: true
|
||||
Max Color: 255; 255; 255
|
||||
Min Color: 0; 0; 0
|
||||
Name: PointCloud2
|
||||
Name: All_traj
|
||||
Position Transformer: XYZ
|
||||
Queue Size: 10
|
||||
Selectable: true
|
||||
@ -164,7 +165,7 @@ Visualization Manager:
|
||||
Invert Rainbow: false
|
||||
Max Color: 255; 255; 255
|
||||
Min Color: 0; 0; 0
|
||||
Name: PointCloud2
|
||||
Name: Primitive_Traj
|
||||
Position Transformer: XYZ
|
||||
Queue Size: 10
|
||||
Selectable: true
|
||||
@ -177,11 +178,11 @@ Visualization Manager:
|
||||
Use rainbow: true
|
||||
Value: false
|
||||
Enabled: true
|
||||
Name: Traj
|
||||
Name: Trajectory
|
||||
- Class: rviz/Marker
|
||||
Enabled: true
|
||||
Marker Topic: /quadrotor_simulator_so3/uav
|
||||
Name: Marker
|
||||
Name: Drone
|
||||
Namespaces:
|
||||
mesh: true
|
||||
Queue Size: 100
|
||||
@ -235,14 +236,14 @@ Visualization Manager:
|
||||
Yaw: 3.140000104904175
|
||||
Saved: ~
|
||||
Window Geometry:
|
||||
Depth:
|
||||
collapsed: false
|
||||
Displays:
|
||||
collapsed: false
|
||||
Height: 1016
|
||||
Height: 1600
|
||||
Hide Left Dock: false
|
||||
Hide Right Dock: true
|
||||
Image:
|
||||
collapsed: false
|
||||
QMainWindow State: 000000ff00000000fd0000000400000000000002350000035afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000001d1000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a0049006d0061006700650100000214000001830000001600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000007380000003efc0100000002fb0000000800540069006d0065010000000000000738000003bc00fffffffb0000000800540069006d00650100000000000004500000000000000000000004fd0000035a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
|
||||
QMainWindow State: 000000ff00000000fd0000000400000000000003310000053afc0200000009fb0000001200530065006c0065006300740069006f006e00000001e10000009b000000b000fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000b0fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000006e000002d40000018200fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a00440065007000740068010000034e0000025a0000002600ffffff00000001000001b90000035afc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d0000035a0000013200fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e1000001970000000300000b700000005efc0100000002fb0000000800540069006d0065010000000000000b70000006dc00fffffffb0000000800540069006d00650100000000000004500000000000000000000008330000053a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
|
||||
Selection:
|
||||
collapsed: false
|
||||
Time:
|
||||
@ -251,6 +252,6 @@ Window Geometry:
|
||||
collapsed: false
|
||||
Views:
|
||||
collapsed: true
|
||||
Width: 1848
|
||||
X: 72
|
||||
Y: 387
|
||||
Width: 2928
|
||||
X: 144
|
||||
Y: 54
|
||||
|
||||
@ -7,13 +7,13 @@
|
||||
python setup.py install
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch2trt import torch2trt
|
||||
from ruamel.yaml import YAML
|
||||
import time
|
||||
from config.config import cfg
|
||||
from policy.yopo_network import YopoNetwork
|
||||
|
||||
|
||||
@ -25,16 +25,15 @@ def parser():
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
if __name__ == "__main__":
|
||||
args = parser().parse_args()
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
cfg = YAML().load(open(os.path.join(base_dir, "config/traj_opt.yaml"), 'r'))
|
||||
weight = base_dir + "/saved/YOPO_{}/epoch{}.pth".format(args.trial, args.epoch)
|
||||
|
||||
print("Loading Network...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
state_dict = torch.load(weight, weights_only=True)
|
||||
policy = YopoNetwork(horizon_num=cfg["horizon_num"], vertical_num=cfg["vertical_num"])
|
||||
policy = YopoNetwork()
|
||||
policy.load_state_dict(state_dict)
|
||||
policy = policy.to(device)
|
||||
policy.eval()
|
||||
@ -77,6 +76,3 @@ def main():
|
||||
f"Transfer Trajectory Error: {traj_error.item():.6f},"
|
||||
f"Transfer Score Error: {score_error.item():.6f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user