simplify goal_norm
This commit is contained in:
parent
cc13b0e293
commit
c8ca7d096c
@ -113,12 +113,10 @@ class StateTransform:
|
|||||||
def normalize_obs(self, vel_acc_goal):
|
def normalize_obs(self, vel_acc_goal):
|
||||||
vel_acc_goal[:, 0:3] = vel_acc_goal[:, 0:3] / self.lattice_primitive.vel_max
|
vel_acc_goal[:, 0:3] = vel_acc_goal[:, 0:3] / self.lattice_primitive.vel_max
|
||||||
vel_acc_goal[:, 3:6] = vel_acc_goal[:, 3:6] / self.lattice_primitive.acc_max
|
vel_acc_goal[:, 3:6] = vel_acc_goal[:, 3:6] / self.lattice_primitive.acc_max
|
||||||
vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / self.goal_length
|
|
||||||
|
|
||||||
# Clamp the goal direction to unit length
|
# Clamp the goal direction to unit length
|
||||||
goal_norm_length = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True)
|
goal_norm = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True)
|
||||||
scaling = goal_norm_length.clamp(max=1.0) / (goal_norm_length + 1e-8)
|
vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / goal_norm.clamp(min=self.goal_length)
|
||||||
vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] * scaling
|
|
||||||
return vel_acc_goal
|
return vel_acc_goal
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user