From bd1480df33de5f31723d59a04f69d778d94acfba Mon Sep 17 00:00:00 2001 From: TJU-Lu Date: Fri, 29 Aug 2025 11:40:26 +0800 Subject: [PATCH] fix primitive num to variable --- YOPO/config/traj_opt.yaml | 2 +- YOPO/policy/state_transform.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/YOPO/config/traj_opt.yaml b/YOPO/config/traj_opt.yaml index 7be5200..1d671ee 100644 --- a/YOPO/config/traj_opt.yaml +++ b/YOPO/config/traj_opt.yaml @@ -11,7 +11,7 @@ wg: 0.12 # guidance ws: 10.0 # smoothness wc: 0.1 # collision -# dataset +# dataset: set image_size = primitive_num × downsampling_factor (×32 for ResNet-18) in each axis dataset_path: "../dataset" image_height: 96 image_width: 160 diff --git a/YOPO/policy/state_transform.py b/YOPO/policy/state_transform.py index eb0f30c..31cefef 100644 --- a/YOPO/policy/state_transform.py +++ b/YOPO/policy/state_transform.py @@ -15,10 +15,10 @@ class StateTransform: endstate_pred: [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h] :return [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h] in body frame """ - B, N = endstate_pred.shape[0], endstate_pred.shape[2] * endstate_pred.shape[3] + B, V, H = endstate_pred.shape[0], endstate_pred.shape[2], endstate_pred.shape[3] # [B, 9, 3, 5] -> [B, 3, 5, 9] -> [B, 15, 9] - endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, N, 9) + endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, V * H, 9) # 获取 lattice angle 和 rotation (.flip: 由于lattice和grid的顺序相反) yaw, pitch = self.lattice_primitive.getAngleLattice() # [15] @@ -47,7 +47,7 @@ class StateTransform: endstate = torch.cat([endstate_p, endstate_vb, endstate_ab], dim=-1) # [B, 15, 9] - endstate = endstate.permute(0, 2, 1).reshape(B, 9, 3, 5) # [B, 9, 3, 5] + endstate = endstate.permute(0, 2, 1).reshape(B, 9, V, H) # [B, 9, 3, 5] return endstate def pred_to_endstate_cpu(self, endstate_pred: np.ndarray, lattice_id: torch.Tensor) -> np.ndarray: