diff --git a/YOPO/policy/state_transform.py b/YOPO/policy/state_transform.py index e051eee..8392449 100644 --- a/YOPO/policy/state_transform.py +++ b/YOPO/policy/state_transform.py @@ -105,8 +105,8 @@ class StateTransform: return out def unnormalize_obs(self, vel_acc): - vel_acc = vel_acc[:, 0:3] * self.lattice_primitive.vel_max - vel_acc = vel_acc[:, 3:6] * self.lattice_primitive.acc_max + vel_acc[:, 0:3] = vel_acc[:, 0:3] * self.lattice_primitive.vel_max + vel_acc[:, 3:6] = vel_acc[:, 3:6] * self.lattice_primitive.acc_max return vel_acc def normalize_obs(self, vel_acc_goal):