simplify inference node, vectorize NumPy operations, fix timing bug.

This commit is contained in:
TJU_Lu 2024-12-17 21:10:46 +08:00
parent 59364bef31
commit 35cd195a10
2 changed files with 77 additions and 96 deletions

View File

@ -88,7 +88,7 @@ class YopoNet:
def callback_set_goal(self, data):
self.goal = np.asarray([data.pose.position.x, data.pose.position.y, 2])
print("New goal:", self.goal)
print("New Goal:", self.goal)
# the first frame
def callback_odometry(self, data):
@ -104,19 +104,20 @@ class YopoNet:
self.new_odom = True
def process_odom(self):
# Rwb
# Rwb -> Rwc -> Rcw
Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y,
self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_matrix()
self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc)
Rotation_cw = self.Rotation_wc.T
if self.odom_ref_init:
odom_data = self.odom_ref
# vel_b
vel_w = np.array([odom_data.twist.twist.linear.x, odom_data.twist.twist.linear.y, odom_data.twist.twist.linear.z])
vel_b = np.dot(np.linalg.inv(self.Rotation_wc), vel_w)
vel_b = np.dot(Rotation_cw, vel_w)
# acc_b (acc stored in angular in our ref_state topic)
acc_w = np.array([odom_data.twist.twist.angular.x, odom_data.twist.twist.angular.y, odom_data.twist.twist.angular.z])
acc_b = np.dot(np.linalg.inv(self.Rotation_wc), acc_w)
acc_b = np.dot(Rotation_cw, acc_w)
else:
odom_data = self.odom
vel_b = np.array([0.0, 0.0, 0.0])
@ -125,7 +126,7 @@ class YopoNet:
# pose and goal_dir
pos = np.array([odom_data.pose.pose.position.x, odom_data.pose.pose.position.y, odom_data.pose.pose.position.z])
goal_w = (self.goal - pos) / np.linalg.norm(self.goal - pos)
goal_b = np.dot(np.linalg.inv(self.Rotation_wc), goal_w)
goal_b = np.dot(Rotation_cw, goal_w)
vel_acc = np.concatenate((vel_b, acc_b), axis=0)
vel_acc_norm = self.normalize_obs(vel_acc[np.newaxis, :])
@ -154,8 +155,7 @@ class YopoNet:
if self.verbose:
self.time_interpolation = self.time_interpolation + (time.time() - start)
self.count_interpolation = self.count_interpolation + 1
print("Time Consuming: interpolation:", self.time_interpolation / self.count_interpolation)
print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms")
# cv2.imshow("1", depth_[0][0])
# cv2.waitKey(1)
self.depth = depth_.astype(np.float32)
@ -164,8 +164,7 @@ class YopoNet:
# TODO: Move the test_policy to callback_depth directly?
def test_policy(self, _timer):
if self.new_depth and self.new_odom:
self.new_odom = False
self.new_depth = False
self.new_odom, self.new_depth = False, False
obs = self.process_odom()
odom_sec = self.odom.header.stamp.to_sec()
@ -176,49 +175,29 @@ class YopoNet:
obs_norm_input = obs_norm_input.to(self.device, non_blocking=True)
# torch.cuda.synchronize()
# forward
if self.use_trt: # TensorRT (inference speed increased by 10x)
time1 = time.time()
trt_output = self.policy(depth, obs_norm_input)
time2 = time.time()
endstate_pred, score_pred = self.trt_process(trt_output, return_all_preds=self.visualize)
endstate_pred = endstate_pred.squeeze()
time3 = time.time()
else:
time1 = time.time()
endstate_pred, score_pred = self.policy.predict(depth, obs_norm_input, return_all_preds=self.visualize)
endstate_pred = endstate_pred.cpu().numpy().squeeze()
score_pred = score_pred.cpu().numpy()
time2 = time3 = time.time()
time1 = time.time()
# Forward (TensorRT: inference speed increased by 5x)
with torch.no_grad():
network_output = self.policy(depth, obs_norm_input)
network_output = network_output.cpu().numpy() # torch.cuda.synchronize() is not needed here
time2 = time.time()
# Replacing PyTorch operation on CUDA with NumPy operation on CPU (speed increased by 10x)
endstate_pred, score_pred = self.process_output(network_output, return_all_preds=self.visualize)
time3 = time.time()
# Transform the prediction(body frame) to the world frame with the attitude in inference
# Replacing PyTorch calculations on CUDA with NumPy calculations on the CPU (speed increased by 10x)
endstate_b = endstate_pred
endstate_w = np.zeros_like(endstate_b)
traj_num = self.lattice_space.horizon_num * self.lattice_space.vertical_num if self.visualize else 1
Pb, Vb, Ab = [np.zeros((3, traj_num)) for _ in range(3)]
for i in range(3):
Pb[i] = endstate_b[3 * i]
Vb[i] = endstate_b[3 * i + 1]
Ab[i] = endstate_b[3 * i + 2]
# pos_actual = np.array([self.odom.pose.pose.position.x,
# self.odom.pose.pose.position.y,
# self.odom.pose.pose.position.z])
Pw = np.dot(self.Rotation_wc, Pb) # + np.tile(pos_actual, (15, 1)).T
Vw = np.dot(self.Rotation_wc, Vb)
Aw = np.dot(self.Rotation_wc, Ab)
for i in range(3):
endstate_w[3 * i] = Pw[i]
endstate_w[3 * i + 1] = Vw[i]
endstate_w[3 * i + 2] = Aw[i]
# Vectorization: transform the prediction(P V A in body frame) to the world frame with the attitude (without the position)
endstate_c = endstate_pred.T.reshape(-1, 3, 3)
endstate_w = np.matmul(self.Rotation_wc, endstate_c)
endstate_w = endstate_w.reshape(-1, 9).T
if self.verbose:
self.time_prepare = self.time_prepare + (time1 - time0)
self.time_forward = self.time_forward + (time2 - time1)
self.time_process = self.time_process + (time3 - time2)
self.count = self.count + 1
print("Time Consuming: prepare:", self.time_prepare / self.count, "; forward:", self.time_forward / self.count,
"; process:", self.time_process / self.count)
print(f"Time Consuming: data-prepare: {1000 * self.time_prepare / self.count:.2f}ms; "
f"network-inference: {1000 * self.time_forward / self.count:.2f}ms; "
f"post-process: {1000 * self.time_process / self.count:.2f}ms")
# publish
if not self.visualize:
@ -232,7 +211,7 @@ class YopoNet:
endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms)
self.endstate_pub.publish(endstate_pred_to_pub)
# visualization
endstate_score_preds = np.concatenate((endstate_w, score_pred), axis=0)
endstate_score_preds = np.vstack([endstate_w, score_pred])
all_endstate_pred = Float32MultiArray(data=endstate_score_preds.T.reshape(-1))
all_endstate_pred.layout.dim.append(MultiArrayDimension())
all_endstate_pred.layout.dim[0].size = endstate_score_preds.shape[1]
@ -246,28 +225,25 @@ class YopoNet:
elif not self.new_odom:
self.odom_ref_init = False
def trt_process(self, input_tensor: torch.Tensor, return_all_preds=False) -> torch.Tensor:
batch_size = input_tensor.shape[0]
input_tensor = input_tensor.cpu().numpy()
input_tensor = input_tensor.reshape(batch_size, 10, self.lattice_space.horizon_num * self.lattice_space.vertical_num)
endstate_pred = input_tensor[:, 0:9, :]
score_pred = input_tensor[:, 9, :]
def process_output(self, network_output, return_all_preds=False):
if network_output.shape[0] != 1:
raise ValueError("batch of output values must be 1 in test!")
network_output = network_output.reshape(10, self.lattice_space.horizon_num * self.lattice_space.vertical_num)
endstate_pred = network_output[0:9, :]
score_pred = network_output[9, :]
if not return_all_preds:
endstate_prediction = np.zeros((batch_size, 9))
score_prediction = np.zeros((batch_size, 1))
for i in range(0, batch_size):
action_id = np.argmin(score_pred[i])
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id
endstate_prediction[i] = self.pred_to_endstate(np.expand_dims(endstate_pred[i, :, action_id], axis=0), lattice_id)
score_prediction[i] = score_pred[i, action_id]
action_id = np.argmin(score_pred)
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id
endstate_prediction = self.pred_to_endstate(endstate_pred[:, action_id], lattice_id)
endstate_prediction = endstate_prediction[:, np.newaxis]
score_prediction = score_pred[action_id]
else:
endstate_prediction = np.zeros_like(endstate_pred)
score_prediction = score_pred
for i in range(0, self.lattice_space.horizon_num * self.lattice_space.vertical_num):
for i in range(self.lattice_space.horizon_num * self.lattice_space.vertical_num):
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i
endstate = self.pred_to_endstate(endstate_pred[:, :, i], lattice_id)
endstate_prediction[:, :, i] = endstate
endstate_prediction[:, i] = self.pred_to_endstate(endstate_pred[:, i], lattice_id)
return endstate_prediction, score_prediction
@ -276,45 +252,40 @@ class YopoNet:
convert the observation from body frame to primitive frame,
and then concatenate it with the depth features (to ensure the translational invariance)
"""
obs_return = np.ones((obs.shape[0], self.lattice_space.vertical_num, self.lattice_space.horizon_num, obs.shape[1]), dtype=np.float32)
if obs.shape[0] != 1:
raise ValueError("batch of input observations must be 1 in test!")
obs_return = np.ones((obs.shape[0], obs.shape[1], self.lattice_space.vertical_num, self.lattice_space.horizon_num), dtype=np.float32)
id = 0
v_b = obs[:, 0:3]
a_b = obs[:, 3:6]
g_b = obs[:, 6:9]
obs_reshaped = obs.reshape(3, 3)
for i in range(self.lattice_space.vertical_num - 1, -1, -1):
for j in range(self.lattice_space.horizon_num - 1, -1, -1):
Rbp = self.lattice_primitive.getRotation(id)
v_p = np.dot(Rbp.T, v_b.T).T
a_p = np.dot(Rbp.T, a_b.T).T
g_p = np.dot(Rbp.T, g_b.T).T
obs_return[:, i, j, 0:3] = v_p
obs_return[:, i, j, 3:6] = a_p
obs_return[:, i, j, 6:9] = g_p
# obs_return[:, i, j, 0:6] = self.normalize_obs(obs_return[:, i, j, 0:6])
obs_return_reshaped = np.dot(obs_reshaped, Rbp)
obs_return[:, :, i, j] = obs_return_reshaped.reshape(9)
id = id + 1
obs_return = np.transpose(obs_return, [0, 3, 1, 2])
return torch.from_numpy(obs_return)
def pred_to_endstate(self, endstate_pred: np.ndarray, id: int):
"""
Transform the predicted state to the body frame.
"""
delta_yaw = endstate_pred[:, 0] * self.lattice_primitive.yaw_diff
delta_pitch = endstate_pred[:, 1] * self.lattice_primitive.pitch_diff
radio = endstate_pred[:, 2] * self.lattice_space.radio_range + self.lattice_space.radio_range
delta_yaw = endstate_pred[0] * self.lattice_primitive.yaw_diff
delta_pitch = endstate_pred[1] * self.lattice_primitive.pitch_diff
radio = endstate_pred[2] * self.lattice_space.radio_range + self.lattice_space.radio_range
yaw, pitch = self.lattice_primitive.getAngleLattice(id)
endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio
endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio
endstate_z = np.sin(pitch + delta_pitch) * radio
endstate_p = np.stack((endstate_x, endstate_y, endstate_z), axis=1)
endstate_p = np.array((endstate_x, endstate_y, endstate_z))
endstate_vp = endstate_pred[:, 3:6] * self.lattice_space.vel_max
endstate_ap = endstate_pred[:, 6:9] * self.lattice_space.acc_max
Rbp = self.lattice_primitive.getRotation(id)
endstate_vb = np.matmul(Rbp, endstate_vp.T).T
endstate_ab = np.matmul(Rbp, endstate_ap.T).T
endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab), axis=1)
endstate[:, [0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[:, [0, 3, 6, 1, 4, 7, 2, 5, 8]]
endstate_vp = endstate_pred[3:6] * self.lattice_space.vel_max
endstate_ap = endstate_pred[6:9] * self.lattice_space.acc_max
Rpb = self.lattice_primitive.getRotation(id).T
endstate_vb = np.matmul(endstate_vp, Rpb)
endstate_ab = np.matmul(endstate_ap, Rpb)
endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab))
endstate[[0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[[0, 3, 6, 1, 4, 7, 2, 5, 8]]
return endstate
def normalize_obs(self, vel_acc):
@ -326,11 +297,8 @@ class YopoNet:
depth = np.zeros(shape=[1, 1, self.height, self.width], dtype=np.float32)
obs = np.zeros(shape=[1, 9], dtype=np.float32)
obs_input = self.prepare_input_observation(obs)
if self.use_trt:
trt_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device))
self.trt_process(trt_output, return_all_preds=True)
else:
self.policy.predict(torch.from_numpy(depth).to(self.device), obs_input.to(self.device), return_all_preds=True)
network_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device))
self.process_output(network_output.cpu().numpy(), return_all_preds=True)
def parser():
@ -339,6 +307,7 @@ def parser():
parser.add_argument("--trial", type=int, default=1, help="trial number")
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
parser.add_argument("--iter", type=int, default=0, help="iter number")
parser.add_argument("--trt_file", type=str, default='yopo_trt.pth', help="tensorrt filename")
return parser
@ -348,9 +317,10 @@ def main():
args = parser().parse_args()
rsg_root = os.path.dirname(os.path.abspath(__file__))
if args.use_tensorrt:
weight = "yopo_trt.pth"
weight = args.trt_file
else:
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
print("load weight from:", weight)
settings = {'use_tensorrt': args.use_tensorrt,
'network_frequency': 30,

View File

@ -46,6 +46,7 @@ def parser():
parser.add_argument("--trial", type=int, default=1, help="trial number")
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
parser.add_argument("--iter", type=int, default=0, help="iter number")
parser.add_argument("--fp16_mode", type=int, default=1, help="fp16 or fp32")
parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name")
return parser
@ -75,6 +76,7 @@ if __name__ == "__main__":
saved_variables = torch.load(weight, map_location=device)
model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
model.policy.set_training_mode(False)
torch.set_grad_enabled(False)
lattice_space = saved_variables["data"]["lattice_space"]
lattice_primitive = saved_variables["data"]["lattice_primitive"]
@ -86,7 +88,7 @@ if __name__ == "__main__":
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
depth_in = torch.from_numpy(depth).cuda()
obs_in = torch.from_numpy(obs_input).cuda()
model_trt = torch2trt(model.policy, [depth_in, obs_in])
model_trt = torch2trt(model.policy, [depth_in, obs_in], fp16_mode=args.fp16_mode)
torch.save(model_trt.state_dict(), args.filename)
print("TensorRT Transfer Finish!")
@ -95,18 +97,27 @@ if __name__ == "__main__":
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
print("Evaluation...")
# warm up...
# Warm Up...
y_trt = model_trt(depth_in, obs_in)
y = model.policy(depth_in, obs_in)
torch.cuda.synchronize()
# PyTorch Latency
torch_start = time.time()
y = model.policy(depth_in, obs_in)
torch.cuda.synchronize()
torch_end = time.time()
# TensorRT Latency
trt_start = time.time()
y_trt = model_trt(depth_in, obs_in)
torch.cuda.synchronize()
trt_end = time.time()
# Transfer Error
error = torch.mean(torch.abs(y - y_trt))
print("Torch Latency: ", 1000 * (torch_end - torch_start),
"ms, TensorRT Latency: ", 1000 * (trt_end - torch_end),
"ms, Transfer Error: ", error.item())
print(f"Torch Latency: {1000 * (torch_end - torch_start):.3f} ms, "
f"TensorRT Latency: {1000 * (trt_end - trt_start):.3f} ms, "
f"Transfer Error: {error.item():.8f}")