fix primitive num to variable

This commit is contained in:
TJU-Lu 2025-08-29 11:40:26 +08:00
parent d14118eefd
commit bd1480df33
2 changed files with 4 additions and 4 deletions

View File

@ -11,7 +11,7 @@ wg: 0.12 # guidance
ws: 10.0 # smoothness
wc: 0.1 # collision
# dataset
# dataset: set image_size = primitive_num × downsampling_factor (×32 for ResNet-18) in each axis
dataset_path: "../dataset"
image_height: 96
image_width: 160

View File

@ -15,10 +15,10 @@ class StateTransform:
endstate_pred: [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h]
:return [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h] in body frame
"""
B, N = endstate_pred.shape[0], endstate_pred.shape[2] * endstate_pred.shape[3]
B, V, H = endstate_pred.shape[0], endstate_pred.shape[2], endstate_pred.shape[3]
# [B, 9, 3, 5] -> [B, 3, 5, 9] -> [B, 15, 9]
endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, N, 9)
endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, V * H, 9)
# 获取 lattice angle 和 rotation (.flip: 由于lattice和grid的顺序相反)
yaw, pitch = self.lattice_primitive.getAngleLattice() # [15]
@ -47,7 +47,7 @@ class StateTransform:
endstate = torch.cat([endstate_p, endstate_vb, endstate_ab], dim=-1) # [B, 15, 9]
endstate = endstate.permute(0, 2, 1).reshape(B, 9, 3, 5) # [B, 9, 3, 5]
endstate = endstate.permute(0, 2, 1).reshape(B, 9, V, H) # [B, 9, 3, 5]
return endstate
def pred_to_endstate_cpu(self, endstate_pred: np.ndarray, lattice_id: torch.Tensor) -> np.ndarray: