Add dynamic weight adjustment based on speed

This commit is contained in:
TJU_Lu 2025-07-05 17:05:28 +08:00
parent fa7f026173
commit 788e9bc979
3 changed files with 47 additions and 31 deletions

View File

@ -5,10 +5,10 @@ velocity: 6.0
vel_align: 6.0
acc_align: 6.0
# IMPORTANT PARAM: weight of penalties (6m/s)
# IMPORTANT PARAM: weight of penalties (for unit speed)
wg: 0.1 # guidance
ws: 0.0015 # smoothness
wc: 0.5 # collision
ws: 10.0 # smoothness
wc: 0.1 # collision
# dataset
dataset_path: "../dataset"
@ -29,18 +29,18 @@ radio_num: 1 # only support 1 currently
d0: 1.2
r: 0.6
# distribution parameters for unit state sampling (no need to adjust)
vx_mean_unit: 0.5
# distribution parameters for unit state sampling
vx_mean_unit: 0.4
vy_mean_unit: 0.0
vz_mean_unit: 0.0
vx_std_unit: 1.7
vx_std_unit: 2.0
vy_std_unit: 0.45
vz_std_unit: 0.3
ax_mean_unit: 0.0
ay_mean_unit: 0.0
az_mean_unit: 0.0
ax_std_unit: 0.4
ay_std_unit: 0.4
ax_std_unit: 0.5
ay_std_unit: 0.5
az_std_unit: 0.3
goal_pitch_std: 10.0 # clip goal_length to: 2 * radio_range; 10% probability [0, 2 * radio_range]
goal_yaw_std: 20.0

View File

@ -18,19 +18,21 @@ class YOPOLoss(nn.Module):
"""
super(YOPOLoss, self).__init__()
base_dir = os.path.dirname(os.path.abspath(__file__))
self.cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.sgm_time = 2 * self.cfg["radio_range"] / self.cfg["velocity"]
cfg = YAML().load(open(os.path.join(base_dir, "../config/traj_opt.yaml"), 'r'))
self.sgm_time = 2 * cfg["radio_range"] / cfg["velocity"]
self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
self._C, self._B, self._L, self._R = self.qp_generation()
self._R = self._R.to(self.device)
self._L = self._L.to(self.device)
self.smoothness_weight = self.cfg["ws"]
self.safety_weight = self.cfg["wc"]
self.goal_weight = self.cfg["wg"]
vel_scale = cfg["velocity"] / 1.0
self.smoothness_weight = cfg["ws"]
self.safety_weight = cfg["wc"]
self.goal_weight = cfg["wg"]
self.denormalize_weight(vel_scale)
self.smoothness_loss = SmoothnessLoss(self._R)
self.safety_loss = SafetyLoss(self._L, self.sgm_time)
self.goal_loss = GuidanceLoss()
print("---------- Loss ---------")
print("------ Actual Loss ------")
print(f"| {'smooth':<12} = {self.smoothness_weight:6.4f} |")
print(f"| {'safety':<12} = {self.safety_weight:6.4f} |")
print(f"| {'goal':<12} = {self.goal_weight:6.4f} |")
@ -68,6 +70,20 @@ class YOPOLoss(nn.Module):
return _C, B, _L, _R
def denormalize_weight(self, vel_scale):
"""
Denormalize the cost weight to ensure consistency across different speeds to simplify parameter tuning.
smoothness cost: time integral of jerk² is used as a smoothness cost.
If the speed is scaled by n, the cost is scaled by n⁵ (because jerk * n⁶ and time * 1/n).
safety cost: time integral of the distance from trajectory to obstacles.
If the speed is scaled by n, the cost is scaled by 1/n (because time * 1/n).
goal cost: projection of the trajectory onto goal direction.
Independent of speed.
"""
self.smoothness_weight = self.smoothness_weight / vel_scale ** 5
self.safety_weight = self.safety_weight * vel_scale
self.goal_weight = self.goal_weight
def forward(self, state, prediction, goal, map_id):
"""
Args:

View File

@ -18,8 +18,8 @@ class YOPODataset(Dataset):
self.width = int(cfg["image_width"])
# ramdom state: x-direction: log-normal distribution, yz-direction: normal distribution
scale = cfg["velocity"] / cfg["vel_align"]
self.vel_scale = scale * cfg["vel_align"]
self.acc_scale = scale * scale * cfg["acc_align"]
self.vel_max = scale * cfg["vel_align"]
self.acc_max = scale * scale * cfg["acc_align"]
self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"])
self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"])
self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
@ -106,18 +106,18 @@ class YOPODataset(Dataset):
def _get_random_state(self):
while True:
vel = self.vel_scale * (self.v_mean + self.v_std * np.random.randn(3))
vel = self.vel_max * (self.v_mean + self.v_std * np.random.randn(3))
right_skewed_vx = -1
while right_skewed_vx < 0:
right_skewed_vx = self.vel_scale * np.random.lognormal(mean=self.vx_lognorm_mean, sigma=self.vx_logmorm_sigma, size=None)
right_skewed_vx = -right_skewed_vx + 1.2 * self.vel_scale # * 1.2 to ensure v_max can be sampled
right_skewed_vx = self.vel_max * np.random.lognormal(mean=self.vx_lognorm_mean, sigma=self.vx_logmorm_sigma, size=None)
right_skewed_vx = -right_skewed_vx + 1.2 * self.vel_max # * 1.2 to ensure v_max can be sampled
vel[0] = right_skewed_vx
if np.linalg.norm(vel) < 1.2 * self.vel_scale: # avoid outliers
if np.linalg.norm(vel) < 1.2 * self.vel_max: # avoid outliers
break
while True:
acc = self.acc_scale * (self.a_mean + self.a_std * np.random.randn(3))
if np.linalg.norm(acc) < 1.2 * self.acc_scale: # avoid outliers
acc = self.acc_max * (self.a_mean + self.a_std * np.random.randn(3))
if np.linalg.norm(acc) < 1.2 * self.acc_max: # avoid outliers
break
return vel, acc
@ -136,19 +136,19 @@ class YOPODataset(Dataset):
def print_data(self):
import scipy.stats as stats
# 计算Vx 5% ~ 95% 区间
p5 = self.vel_scale * np.exp(stats.norm.ppf(0.05, loc=self.vx_lognorm_mean, scale=self.vx_logmorm_sigma))
p95 = self.vel_scale * np.exp(stats.norm.ppf(0.95, loc=self.vx_lognorm_mean, scale=self.vx_logmorm_sigma))
p5 = self.vel_max * np.exp(stats.norm.ppf(0.05, loc=self.vx_lognorm_mean, scale=self.vx_logmorm_sigma))
p95 = self.vel_max * np.exp(stats.norm.ppf(0.95, loc=self.vx_lognorm_mean, scale=self.vx_logmorm_sigma))
v_lower = self.vel_scale * (self.v_mean - 2 * self.v_std)
v_upper = self.vel_scale * (self.v_mean + 2 * self.v_std)
v_lower[0] = -p95 + 1.2 * self.vel_scale
v_upper[0] = -p5 + 1.2 * self.vel_scale
v_lower = self.vel_max * (self.v_mean - 2 * self.v_std)
v_upper = self.vel_max * (self.v_mean + 2 * self.v_std)
v_lower[0] = max(-p95 + 1.2 * self.vel_max, 0)
v_upper[0] = -p5 + 1.2 * self.vel_max
a_lower = self.acc_scale * (self.a_mean - 2 * self.a_std)
a_upper = self.acc_scale * (self.a_mean + 2 * self.a_std)
a_lower = self.acc_max * (self.a_mean - 2 * self.a_std)
a_upper = self.acc_max * (self.a_mean + 2 * self.a_std)
print("----------------- Sampling State --------------------")
print("| X-Y-Z | Vel 90% Range(m/s) | Acc 90% Range(m/s2) |")
print("| X-Y-Z | Vel 95% Range(m/s) | Acc 95% Range(m/s2) |")
print("|-------|---------------------|---------------------|")
for i in range(3):
print(f"| {i:^4} | {v_lower[i]:^9.1f}~{v_upper[i]:^9.1f} |"