diff --git a/YOPO/policy/state_transform.py b/YOPO/policy/state_transform.py index 79a89c6..b4a0cba 100644 --- a/YOPO/policy/state_transform.py +++ b/YOPO/policy/state_transform.py @@ -113,12 +113,10 @@ class StateTransform: def normalize_obs(self, vel_acc_goal): vel_acc_goal[:, 0:3] = vel_acc_goal[:, 0:3] / self.lattice_primitive.vel_max vel_acc_goal[:, 3:6] = vel_acc_goal[:, 3:6] / self.lattice_primitive.acc_max - vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / self.goal_length # Clamp the goal direction to unit length - goal_norm_length = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True) - scaling = goal_norm_length.clamp(max=1.0) / (goal_norm_length + 1e-8) - vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] * scaling + goal_norm = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True) + vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / goal_norm.clamp(min=self.goal_length) return vel_acc_goal