Fix a minor issue in safety cost
This commit is contained in:
parent
42e9722597
commit
315063477d
@ -93,17 +93,13 @@ class SafetyLoss(nn.Module):
|
||||
grid_point = 2.0 * grid / (local_shape - 1).unsqueeze(1) - 1.0 # (B, N, 3)
|
||||
|
||||
grid_point = grid_point.view(B, 1, 1, N, 3)
|
||||
|
||||
valid_mask = ((grid_point < 0.99).all(-1) & (grid_point > -0.99).all(-1)).squeeze(dim=1).squeeze(dim=1) # (B, N)
|
||||
grid_point = th.clamp(grid_point, min=-0.99, max=0.99) # (B, N)
|
||||
|
||||
dist_query = F.grid_sample(sdf_maps, grid_point, mode='bilinear', padding_mode='zeros', align_corners=True) # (B, 1, 1, 1, N)
|
||||
dist_query = dist_query.view(B, N)
|
||||
|
||||
# Cost function
|
||||
cost = self.cost_function(dist_query) # (B, N)
|
||||
|
||||
cost = cost.masked_fill(~valid_mask, 0.0)
|
||||
|
||||
return cost, dist_query
|
||||
|
||||
def cost_function(self, d):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user