Fix a minor issue in safety cost

This commit is contained in:
TJU-Lu 2025-11-05 19:24:25 +08:00
parent 42e9722597
commit 315063477d

View File

@ -93,17 +93,13 @@ class SafetyLoss(nn.Module):
grid_point = 2.0 * grid / (local_shape - 1).unsqueeze(1) - 1.0 # (B, N, 3)
grid_point = grid_point.view(B, 1, 1, N, 3)
valid_mask = ((grid_point < 0.99).all(-1) & (grid_point > -0.99).all(-1)).squeeze(dim=1).squeeze(dim=1) # (B, N)
grid_point = th.clamp(grid_point, min=-0.99, max=0.99) # (B, N)
dist_query = F.grid_sample(sdf_maps, grid_point, mode='bilinear', padding_mode='zeros', align_corners=True) # (B, 1, 1, 1, N)
dist_query = dist_query.view(B, N)
# Cost function
cost = self.cost_function(dist_query) # (B, N)
cost = cost.masked_fill(~valid_mask, 0.0)
return cost, dist_query
def cost_function(self, d):