Add truncate collision cost

This commit is contained in:
Your Name Here 2025-07-20 14:56:26 +08:00
parent 6f7060838c
commit 45ab454a23

View File

@ -22,12 +22,13 @@ class SafetyLoss(nn.Module):
self.sgm_time = cfg["sgm_time"]
self.eval_points = 30
self.device = self._L.device
self.truncate_cost = False # truncate cost at collision or use full trajectory cost
# SDF
self.voxel_size = 0.2
self.min_bounds = None # shape: (N, 3)
self.max_bounds = None # shape: (N, 3)
self.sdf_shapes = None # shape: (1, 3)
self.sdf_shapes = None # shape: (N, 3)
print("Building ESDF map...")
base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
@ -54,11 +55,31 @@ class SafetyLoss(nn.Module):
# get pos from coeff [B*H*V, N, 3] -> [B, H*V*N, 3]
pos_coe = self.get_position_from_coeff(coe, t_list)
pos_batch = pos_coe.reshape(-1, self.traj_num * pos_coe.shape[1], 3)
# get info from sdf_map
cost = self.get_distance_cost(pos_batch, map_id)
cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N]
cost_colli = cost_dt.sum(dim=-1)
# get info from sdf_map
cost, dist = self.get_distance_cost(pos_batch, map_id)
if not self.truncate_cost:
# Compute time integral of full trajectory cost (for general scenario)
cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N]
cost_colli = cost_dt.sum(dim=-1)
else:
# Only compute cost before the first collision (better for large-obstacle scenario)
dist = dist.view(batch_size, -1) # [B*H*V, N]
cost = cost.view(batch_size, -1) # [B*H*V, N]
N = dist.shape[1]
mask = dist <= 0 # [B*H*V, N]
index = th.where(mask, th.arange(N).to(self.device).expand(batch_size, N), N - 1)
first_colli_idx = index.min(dim=1).values # [B*H*V]
arange = th.arange(N).to(self.device).unsqueeze(0).expand(batch_size, N) # [B*H*V, N]
valid_mask = arange <= first_colli_idx.unsqueeze(1) # [B*H*V, N]
masked_cost = cost * valid_mask
valid_count = first_colli_idx + 1
cost_colli = self.sgm_time * masked_cost.sum(dim=-1) / valid_count
return cost_colli
@ -93,7 +114,7 @@ class SafetyLoss(nn.Module):
cost = cost.masked_fill(~valid_mask, 0.0)
return cost
return cost, dist_query
def cost_function(self, d):
return th.exp(-(d - self.d0) / self.r)