Fix a minor issue in safety cost
This commit is contained in:
parent
42e9722597
commit
315063477d
@ -93,17 +93,13 @@ class SafetyLoss(nn.Module):
|
|||||||
grid_point = 2.0 * grid / (local_shape - 1).unsqueeze(1) - 1.0 # (B, N, 3)
|
grid_point = 2.0 * grid / (local_shape - 1).unsqueeze(1) - 1.0 # (B, N, 3)
|
||||||
|
|
||||||
grid_point = grid_point.view(B, 1, 1, N, 3)
|
grid_point = grid_point.view(B, 1, 1, N, 3)
|
||||||
|
grid_point = th.clamp(grid_point, min=-0.99, max=0.99) # (B, N)
|
||||||
valid_mask = ((grid_point < 0.99).all(-1) & (grid_point > -0.99).all(-1)).squeeze(dim=1).squeeze(dim=1) # (B, N)
|
|
||||||
|
|
||||||
dist_query = F.grid_sample(sdf_maps, grid_point, mode='bilinear', padding_mode='zeros', align_corners=True) # (B, 1, 1, 1, N)
|
dist_query = F.grid_sample(sdf_maps, grid_point, mode='bilinear', padding_mode='zeros', align_corners=True) # (B, 1, 1, 1, N)
|
||||||
dist_query = dist_query.view(B, N)
|
dist_query = dist_query.view(B, N)
|
||||||
|
|
||||||
# Cost function
|
# Cost function
|
||||||
cost = self.cost_function(dist_query) # (B, N)
|
cost = self.cost_function(dist_query) # (B, N)
|
||||||
|
|
||||||
cost = cost.masked_fill(~valid_mask, 0.0)
|
|
||||||
|
|
||||||
return cost, dist_query
|
return cost, dist_query
|
||||||
|
|
||||||
def cost_function(self, d):
|
def cost_function(self, d):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user