simplify goal_norm

This commit is contained in:
TJU_Lu 2025-06-29 23:36:04 +08:00
parent cc13b0e293
commit c8ca7d096c

View File

@ -113,12 +113,10 @@ class StateTransform:
def normalize_obs(self, vel_acc_goal): def normalize_obs(self, vel_acc_goal):
vel_acc_goal[:, 0:3] = vel_acc_goal[:, 0:3] / self.lattice_primitive.vel_max vel_acc_goal[:, 0:3] = vel_acc_goal[:, 0:3] / self.lattice_primitive.vel_max
vel_acc_goal[:, 3:6] = vel_acc_goal[:, 3:6] / self.lattice_primitive.acc_max vel_acc_goal[:, 3:6] = vel_acc_goal[:, 3:6] / self.lattice_primitive.acc_max
vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / self.goal_length
# Clamp the goal direction to unit length # Clamp the goal direction to unit length
goal_norm_length = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True) goal_norm = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True)
scaling = goal_norm_length.clamp(max=1.0) / (goal_norm_length + 1e-8) vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / goal_norm.clamp(min=self.goal_length)
vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] * scaling
return vel_acc_goal return vel_acc_goal