diff --git a/YOPO/loss/safety_loss.py b/YOPO/loss/safety_loss.py index 8dc37ab..793b068 100644 --- a/YOPO/loss/safety_loss.py +++ b/YOPO/loss/safety_loss.py @@ -93,17 +93,13 @@ class SafetyLoss(nn.Module): grid_point = 2.0 * grid / (local_shape - 1).unsqueeze(1) - 1.0 # (B, N, 3) grid_point = grid_point.view(B, 1, 1, N, 3) - - valid_mask = ((grid_point < 0.99).all(-1) & (grid_point > -0.99).all(-1)).squeeze(dim=1).squeeze(dim=1) # (B, N) + grid_point = th.clamp(grid_point, min=-0.99, max=0.99) # (B, N) dist_query = F.grid_sample(sdf_maps, grid_point, mode='bilinear', padding_mode='zeros', align_corners=True) # (B, 1, 1, 1, N) dist_query = dist_query.view(B, N) # Cost function cost = self.cost_function(dist_query) # (B, N) - - cost = cost.masked_fill(~valid_mask, 0.0) - return cost, dist_query def cost_function(self, d):