Add truncate collision cost
This commit is contained in:
parent
6f7060838c
commit
45ab454a23
@ -22,12 +22,13 @@ class SafetyLoss(nn.Module):
|
||||
self.sgm_time = cfg["sgm_time"]
|
||||
self.eval_points = 30
|
||||
self.device = self._L.device
|
||||
self.truncate_cost = False # truncate cost at collision or use full trajectory cost
|
||||
|
||||
# SDF
|
||||
self.voxel_size = 0.2
|
||||
self.min_bounds = None # shape: (N, 3)
|
||||
self.max_bounds = None # shape: (N, 3)
|
||||
self.sdf_shapes = None # shape: (1, 3)
|
||||
self.sdf_shapes = None # shape: (N, 3)
|
||||
print("Building ESDF map...")
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
||||
@ -54,11 +55,31 @@ class SafetyLoss(nn.Module):
|
||||
# get pos from coeff [B*H*V, N, 3] -> [B, H*V*N, 3]
|
||||
pos_coe = self.get_position_from_coeff(coe, t_list)
|
||||
pos_batch = pos_coe.reshape(-1, self.traj_num * pos_coe.shape[1], 3)
|
||||
# get info from sdf_map
|
||||
cost = self.get_distance_cost(pos_batch, map_id)
|
||||
|
||||
cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N]
|
||||
cost_colli = cost_dt.sum(dim=-1)
|
||||
# get info from sdf_map
|
||||
cost, dist = self.get_distance_cost(pos_batch, map_id)
|
||||
|
||||
if not self.truncate_cost:
|
||||
# Compute time integral of full trajectory cost (for general scenario)
|
||||
cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N]
|
||||
cost_colli = cost_dt.sum(dim=-1)
|
||||
else:
|
||||
# Only compute cost before the first collision (better for large-obstacle scenario)
|
||||
dist = dist.view(batch_size, -1) # [B*H*V, N]
|
||||
cost = cost.view(batch_size, -1) # [B*H*V, N]
|
||||
|
||||
N = dist.shape[1]
|
||||
mask = dist <= 0 # [B*H*V, N]
|
||||
index = th.where(mask, th.arange(N).to(self.device).expand(batch_size, N), N - 1)
|
||||
first_colli_idx = index.min(dim=1).values # [B*H*V]
|
||||
|
||||
arange = th.arange(N).to(self.device).unsqueeze(0).expand(batch_size, N) # [B*H*V, N]
|
||||
valid_mask = arange <= first_colli_idx.unsqueeze(1) # [B*H*V, N]
|
||||
|
||||
masked_cost = cost * valid_mask
|
||||
valid_count = first_colli_idx + 1
|
||||
|
||||
cost_colli = self.sgm_time * masked_cost.sum(dim=-1) / valid_count
|
||||
|
||||
return cost_colli
|
||||
|
||||
@ -93,7 +114,7 @@ class SafetyLoss(nn.Module):
|
||||
|
||||
cost = cost.masked_fill(~valid_mask, 0.0)
|
||||
|
||||
return cost
|
||||
return cost, dist_query
|
||||
|
||||
def cost_function(self, d):
|
||||
return th.exp(-(d - self.d0) / self.r)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user