From 45ab454a239e879f6305d60ebcbecf0faa174a86 Mon Sep 17 00:00:00 2001 From: Your Name Here Date: Sun, 20 Jul 2025 14:56:26 +0800 Subject: [PATCH] Add truncate collision cost --- YOPO/loss/safety_loss.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/YOPO/loss/safety_loss.py b/YOPO/loss/safety_loss.py index cbe5f4f..755c64f 100644 --- a/YOPO/loss/safety_loss.py +++ b/YOPO/loss/safety_loss.py @@ -22,12 +22,13 @@ class SafetyLoss(nn.Module): self.sgm_time = cfg["sgm_time"] self.eval_points = 30 self.device = self._L.device + self.truncate_cost = False # truncate cost at collision or use full trajectory cost # SDF self.voxel_size = 0.2 self.min_bounds = None # shape: (N, 3) self.max_bounds = None # shape: (N, 3) - self.sdf_shapes = None # shape: (1, 3) + self.sdf_shapes = None # shape: (N, 3) print("Building ESDF map...") base_dir = os.path.dirname(os.path.abspath(__file__)) data_dir = os.path.join(base_dir, "../", cfg["dataset_path"]) @@ -54,11 +55,31 @@ class SafetyLoss(nn.Module): # get pos from coeff [B*H*V, N, 3] -> [B, H*V*N, 3] pos_coe = self.get_position_from_coeff(coe, t_list) pos_batch = pos_coe.reshape(-1, self.traj_num * pos_coe.shape[1], 3) - # get info from sdf_map - cost = self.get_distance_cost(pos_batch, map_id) - cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N] - cost_colli = cost_dt.sum(dim=-1) + # get info from sdf_map + cost, dist = self.get_distance_cost(pos_batch, map_id) + + if not self.truncate_cost: + # Compute time integral of full trajectory cost (for general scenario) + cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N] + cost_colli = cost_dt.sum(dim=-1) + else: + # Only compute cost before the first collision (better for large-obstacle scenario) + dist = dist.view(batch_size, -1) # [B*H*V, N] + cost = cost.view(batch_size, -1) # [B*H*V, N] + + N = dist.shape[1] + mask = dist <= 0 # [B*H*V, N] + index = th.where(mask, th.arange(N).to(self.device).expand(batch_size, N), N - 1) + first_colli_idx = index.min(dim=1).values # [B*H*V] + + arange = th.arange(N).to(self.device).unsqueeze(0).expand(batch_size, N) # [B*H*V, N] + valid_mask = arange <= first_colli_idx.unsqueeze(1) # [B*H*V, N] + + masked_cost = cost * valid_mask + valid_count = first_colli_idx + 1 + + cost_colli = self.sgm_time * masked_cost.sum(dim=-1) / valid_count return cost_colli @@ -93,7 +114,7 @@ class SafetyLoss(nn.Module): cost = cost.masked_fill(~valid_mask, 0.0) - return cost + return cost, dist_query def cost_function(self, d): return th.exp(-(d - self.d0) / self.r)