YOPO/YOPO/policy/state_transform.py
2025-06-24 23:16:49 +08:00

154 lines
6.8 KiB
Python

import os
import torch
import numpy as np
from ruamel.yaml import YAML
from policy.primitive import LatticePrimitive
class StateTransform:
    """Convert between normalized network tensors and metric states.

    The class wraps a ``LatticePrimitive`` instance and provides:

    * ``pred_to_endstate`` / ``pred_to_endstate_cpu``: turn the normalized
      per-primitive prediction into an end state (position, velocity,
      acceleration) expressed in the body frame.
    * ``prepare_input``: express a body-frame observation in the frame of
      every lattice primitive.
    * ``normalize_obs`` / ``unnormalize_obs``: scale observations by the
      lattice limits.
    """

    def __init__(self):
        """Load ``../config/traj_opt.yaml`` (relative to this file) and fetch the lattice."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        cfg_path = os.path.join(base_dir, "../config/traj_opt.yaml")
        # Use a context manager so the config file handle is always closed.
        with open(cfg_path, 'r') as cfg_file:
            cfg = YAML().load(cfg_file)
        self.lattice_primitive = LatticePrimitive.get_instance(cfg)
        # Goal vectors are divided by this length in normalize_obs().
        self.goal_length = 2.0 * cfg['radio_range']

    def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor:
        """
        Transform the predicted state to the body frame (Original prediction → Primitive frame → Body frame).

        endstate_pred: [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h]
            Channels 0-2 are normalized yaw offset, pitch offset and radius;
            channels 3-5 / 6-8 are normalized velocity / acceleration.
        :return [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h] in body frame,
            with the same grid size (primitive_v, primitive_h) as the input.
        """
        lattice = self.lattice_primitive
        # Take the grid size from the input instead of hard-coding 3 x 5.
        B, V, H = endstate_pred.shape[0], endstate_pred.shape[2], endstate_pred.shape[3]
        N = V * H
        # [B, 9, V, H] -> [B, V, H, 9] -> [B, N, 9]
        endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, N, 9)

        # Get lattice angles and rotations (.flip: the lattice and the grid are ordered oppositely)
        yaw, pitch = lattice.getAngleLattice()                  # [N]
        yaw = yaw.flip(0)[None, :].expand(B, -1)                # [B, N]
        pitch = pitch.flip(0)[None, :].expand(B, -1)            # [B, N]
        Rbp = lattice.getRotation().flip(0)                     # [N, 3, 3]
        Rbp = Rbp[None, :, :, :].expand(B, -1, -1, -1)          # [B, N, 3, 3]

        # Position: lattice direction plus predicted angular offsets, scaled by the predicted radius.
        delta_yaw = endstate_pred[:, :, 0] * lattice.yaw_diff   # [B, N]
        delta_pitch = endstate_pred[:, :, 1] * lattice.pitch_diff
        radio = (endstate_pred[:, :, 2] + 1.0) * lattice.radio_range
        cos_pitch = torch.cos(pitch + delta_pitch)
        endstate_x = cos_pitch * torch.cos(yaw + delta_yaw) * radio
        endstate_y = cos_pitch * torch.sin(yaw + delta_yaw) * radio
        endstate_z = torch.sin(pitch + delta_pitch) * radio
        endstate_p = torch.stack([endstate_x, endstate_y, endstate_z], dim=-1)  # [B, N, 3]

        # vel / acc
        endstate_vp = endstate_pred[:, :, 3:6] * lattice.vel_max  # [B, N, 3]
        endstate_ap = endstate_pred[:, :, 6:9] * lattice.acc_max  # [B, N, 3]

        # Transform v/a to the body frame
        endstate_vb = torch.matmul(Rbp, endstate_vp.unsqueeze(-1)).squeeze(-1)  # [B, N, 3]
        endstate_ab = torch.matmul(Rbp, endstate_ap.unsqueeze(-1)).squeeze(-1)

        endstate = torch.cat([endstate_p, endstate_vb, endstate_ab], dim=-1)    # [B, N, 9]
        endstate = endstate.permute(0, 2, 1).reshape(B, 9, V, H)                # [B, 9, V, H]
        return endstate

    def pred_to_endstate_cpu(self, endstate_pred: np.ndarray, lattice_id: torch.Tensor) -> np.ndarray:
        """
        Used during test:
        Numpy version of pred_to_endstate() on CPU (used in test, x10 times faster than torch on CUDA)

        endstate_pred: [B; 9] normalized prediction, one row per selected primitive
        lattice_id: index passed to the lattice to select each row's angles and rotation
        :return [B; px py pz vx vy vz ax ay az] in body frame
        """
        lattice = self.lattice_primitive
        delta_yaw = endstate_pred[:, 0] * lattice.yaw_diff
        delta_pitch = endstate_pred[:, 1] * lattice.pitch_diff
        radio = (endstate_pred[:, 2] + 1.0) * lattice.radio_range

        yaw, pitch = lattice.getAngleLattice(lattice_id)
        yaw, pitch = yaw.cpu().numpy(), pitch.cpu().numpy()

        cos_pitch = np.cos(pitch + delta_pitch)
        endstate_x = cos_pitch * np.cos(yaw + delta_yaw) * radio
        endstate_y = cos_pitch * np.sin(yaw + delta_yaw) * radio
        endstate_z = np.sin(pitch + delta_pitch) * radio
        endstate_p = np.stack((endstate_x, endstate_y, endstate_z), axis=1)

        endstate_vp = endstate_pred[:, 3:6] * lattice.vel_max
        endstate_ap = endstate_pred[:, 6:9] * lattice.acc_max

        # Same rotation as in pred_to_endstate(), selected per row by lattice_id.
        Rbp = lattice.getRotation(lattice_id).cpu().numpy()
        endstate_vb = np.matmul(Rbp, endstate_vp[:, :, np.newaxis]).squeeze(-1)
        endstate_ab = np.matmul(Rbp, endstate_ap[:, :, np.newaxis]).squeeze(-1)
        return np.concatenate((endstate_p, endstate_vb, endstate_ab), axis=1)

    def prepare_input(self, obs):
        """
        Transform the observation from the body frame to the frame of every primitive.

        obs: [batch; vx, vy, vz, ax, ay, az, gx, gy, gz] in body frame
        :return [batch; vx, vy, vz, ax, ay, az, gx, gy, gz; primitive_v; primitive_h] in primitive frame
        """
        lattice = self.lattice_primitive
        B, N = obs.shape[0], lattice.traj_num
        # Get all Rbp in reversed order (the lattice and the grid are ordered oppositely)
        Rbp_all = lattice.getRotation().flip(0)                 # [N, 3, 3]
        # Rows are (vel, acc, goal); reshape (not view) also accepts non-contiguous input.
        obs = obs.reshape(B, 3, 3)                              # [B, 3, 3]

        # Expand obs and Rbp to [B, N, 3, 3]
        obs_exp = obs[:, None, :, :].expand(B, N, 3, 3)
        Rbp_exp = Rbp_all[None, :, :, :].expand(B, N, 3, 3)

        # Batched coordinate transform: right-multiplying row vectors by Rbp
        # is the same as applying Rbp^T to each vector.
        transformed = torch.matmul(obs_exp, Rbp_exp)            # [B, N, 3, 3]
        transformed_flat = transformed.reshape(B, N, 9)         # [B, N, 9]
        out = transformed_flat.permute(0, 2, 1).contiguous()    # [B, 9, N]
        out = out.view(B, 9, lattice.vertical_num, lattice.horizon_num)  # [B, 9, V, H]
        return out

    def unnormalize_obs(self, vel_acc):
        """
        Scale normalized velocity / acceleration back by vel_max / acc_max.

        vel_acc: [batch; vx, vy, vz, ax, ay, az, ...]; columns 0-5 are
            modified IN PLACE, and the same object is returned.
        """
        vel_acc[:, 0:3] = vel_acc[:, 0:3] * self.lattice_primitive.vel_max
        vel_acc[:, 3:6] = vel_acc[:, 3:6] * self.lattice_primitive.acc_max
        return vel_acc

    def normalize_obs(self, vel_acc_goal):
        """
        Normalize velocity, acceleration and goal of an observation.

        vel_acc_goal: [batch; vx, vy, vz, ax, ay, az, gx, gy, gz]; modified
            IN PLACE, and the same object is returned.
        Velocity is divided by vel_max, acceleration by acc_max and the goal
        by goal_length; the goal vector is then limited to at most unit length.
        """
        vel_acc_goal[:, 0:3] = vel_acc_goal[:, 0:3] / self.lattice_primitive.vel_max
        vel_acc_goal[:, 3:6] = vel_acc_goal[:, 3:6] / self.lattice_primitive.acc_max
        vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / self.goal_length

        # Clamp the goal vector to at most unit length (1e-8 guards against division by zero)
        goal_norm_length = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True)
        scaling = goal_norm_length.clamp(max=1.0) / (goal_norm_length + 1e-8)
        vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] * scaling
        return vel_acc_goal
def rotate_body2world(rot_wb, pos_b):
    """
    Rotate a body-frame vector into the world frame with a rotation matrix.

    rot_wb: (..., 3, 3) rotation matrix applied to the vector
    pos_b: (..., 3) vector expressed in the body frame
    :return (..., 3) rotated vector
    """
    # Treat the vector as a column, multiply, then drop the trailing axis again.
    column = pos_b[..., None]
    rotated = rot_wb @ column
    return rotated[..., 0]
def transform_body2world(rot_wb, t_w, pos_b):
    """
    Map a body-frame point into the world frame: rotate by rot_wb, then add t_w.

    rot_wb: (..., 3, 3) rotation matrix
    t_w: (..., 3) translation added after the rotation
    pos_b: (..., 3) point expressed in the body frame
    :return (..., 3) transformed point
    """
    rotated = rotate_body2world(rot_wb, pos_b)
    return rotated + t_w
def state_body2world(pos_w, rot_wb, pos_b, vel_b, acc_b):
    """
    Convert a body-frame state (position, velocity, acceleration) to the world frame.

    The position is rotated by rot_wb and offset by pos_w; velocity and
    acceleration are only rotated.
    :return tuple (position, velocity, acceleration) after the conversion
    """
    pos_world = transform_body2world(rot_wb, pos_w, pos_b)
    vel_world, acc_world = (rotate_body2world(rot_wb, vec) for vec in (vel_b, acc_b))
    return pos_world, vel_world, acc_world
if __name__ == "__main__":
    # Smoke check when run as a script: constructing the transform loads the
    # trajectory-optimization config and fetches the lattice primitive.
    CoordTransform = StateTransform()