fix primitive num to variable
This commit is contained in:
parent
d14118eefd
commit
bd1480df33
@ -11,7 +11,7 @@ wg: 0.12 # guidance
|
||||
ws: 10.0 # smoothness
|
||||
wc: 0.1 # collision
|
||||
|
||||
# dataset
|
||||
# dataset: set image_size = primitive_num × downsampling_factor (×32 for ResNet-18) in each axis
|
||||
dataset_path: "../dataset"
|
||||
image_height: 96
|
||||
image_width: 160
|
||||
|
||||
@ -15,10 +15,10 @@ class StateTransform:
|
||||
endstate_pred: [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h]
|
||||
:return [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h] in body frame
|
||||
"""
|
||||
B, N = endstate_pred.shape[0], endstate_pred.shape[2] * endstate_pred.shape[3]
|
||||
B, V, H = endstate_pred.shape[0], endstate_pred.shape[2], endstate_pred.shape[3]
|
||||
|
||||
# [B, 9, 3, 5] -> [B, 3, 5, 9] -> [B, 15, 9]
|
||||
endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, N, 9)
|
||||
endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, V * H, 9)
|
||||
|
||||
# 获取 lattice angle 和 rotation (.flip: 由于lattice和grid的顺序相反)
|
||||
yaw, pitch = self.lattice_primitive.getAngleLattice() # [15]
|
||||
@ -47,7 +47,7 @@ class StateTransform:
|
||||
|
||||
endstate = torch.cat([endstate_p, endstate_vb, endstate_ab], dim=-1) # [B, 15, 9]
|
||||
|
||||
endstate = endstate.permute(0, 2, 1).reshape(B, 9, 3, 5) # [B, 9, 3, 5]
|
||||
endstate = endstate.permute(0, 2, 1).reshape(B, 9, V, H) # [B, 9, 3, 5]
|
||||
return endstate
|
||||
|
||||
def pred_to_endstate_cpu(self, endstate_pred: np.ndarray, lattice_id: torch.Tensor) -> np.ndarray:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user