From 35cd195a1054338175f307acf539e903d2536ed6 Mon Sep 17 00:00:00 2001 From: TJU_Lu Date: Tue, 17 Dec 2024 21:10:46 +0800 Subject: [PATCH] simplify inference node, vectorize NumPy operations, fix timing bug. --- run/test_yopo_ros.py | 152 ++++++++++++++++----------------------- run/yopo_trt_transfer.py | 21 ++++-- 2 files changed, 77 insertions(+), 96 deletions(-) diff --git a/run/test_yopo_ros.py b/run/test_yopo_ros.py index 68cca4f..cd40dfd 100644 --- a/run/test_yopo_ros.py +++ b/run/test_yopo_ros.py @@ -88,7 +88,7 @@ class YopoNet: def callback_set_goal(self, data): self.goal = np.asarray([data.pose.position.x, data.pose.position.y, 2]) - print("New goal:", self.goal) + print("New Goal:", self.goal) # the first frame def callback_odometry(self, data): @@ -104,19 +104,20 @@ class YopoNet: self.new_odom = True def process_odom(self): - # Rwb + # Rwb -> Rwc -> Rcw Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y, self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_matrix() self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc) + Rotation_cw = self.Rotation_wc.T if self.odom_ref_init: odom_data = self.odom_ref # vel_b vel_w = np.array([odom_data.twist.twist.linear.x, odom_data.twist.twist.linear.y, odom_data.twist.twist.linear.z]) - vel_b = np.dot(np.linalg.inv(self.Rotation_wc), vel_w) + vel_b = np.dot(Rotation_cw, vel_w) # acc_b (acc stored in angular in our ref_state topic) acc_w = np.array([odom_data.twist.twist.angular.x, odom_data.twist.twist.angular.y, odom_data.twist.twist.angular.z]) - acc_b = np.dot(np.linalg.inv(self.Rotation_wc), acc_w) + acc_b = np.dot(Rotation_cw, acc_w) else: odom_data = self.odom vel_b = np.array([0.0, 0.0, 0.0]) @@ -125,7 +126,7 @@ class YopoNet: # pose and goal_dir pos = np.array([odom_data.pose.pose.position.x, odom_data.pose.pose.position.y, odom_data.pose.pose.position.z]) goal_w = (self.goal - pos) / np.linalg.norm(self.goal - pos) - goal_b = np.dot(np.linalg.inv(self.Rotation_wc), goal_w) + goal_b = np.dot(Rotation_cw, goal_w) vel_acc = np.concatenate((vel_b, acc_b), axis=0) vel_acc_norm = self.normalize_obs(vel_acc[np.newaxis, :]) @@ -154,8 +155,7 @@ class YopoNet: if self.verbose: self.time_interpolation = self.time_interpolation + (time.time() - start) self.count_interpolation = self.count_interpolation + 1 - print("Time Consuming: interpolation:", self.time_interpolation / self.count_interpolation) - + print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms") # cv2.imshow("1", depth_[0][0]) # cv2.waitKey(1) self.depth = depth_.astype(np.float32) @@ -164,8 +164,7 @@ class YopoNet: # TODO: Move the test_policy to callback_depth directly? def test_policy(self, _timer): if self.new_depth and self.new_odom: - self.new_odom = False - self.new_depth = False + self.new_odom, self.new_depth = False, False obs = self.process_odom() odom_sec = self.odom.header.stamp.to_sec() @@ -176,49 +175,29 @@ class YopoNet: obs_norm_input = obs_norm_input.to(self.device, non_blocking=True) # torch.cuda.synchronize() - # forward - if self.use_trt: # TensorRT (inference speed increased by 10x) - time1 = time.time() - trt_output = self.policy(depth, obs_norm_input) - time2 = time.time() - endstate_pred, score_pred = self.trt_process(trt_output, return_all_preds=self.visualize) - endstate_pred = endstate_pred.squeeze() - time3 = time.time() - else: - time1 = time.time() - endstate_pred, score_pred = self.policy.predict(depth, obs_norm_input, return_all_preds=self.visualize) - endstate_pred = endstate_pred.cpu().numpy().squeeze() - score_pred = score_pred.cpu().numpy() - time2 = time3 = time.time() + time1 = time.time() + # Forward (TensorRT: inference speed increased by 5x) + with torch.no_grad(): + network_output = self.policy(depth, obs_norm_input) + network_output = network_output.cpu().numpy() # torch.cuda.synchronize() is not needed here + time2 = time.time() + # Replacing PyTorch operation on CUDA with NumPy operation on CPU (speed increased by 10x) + endstate_pred, score_pred = self.process_output(network_output, return_all_preds=self.visualize) + time3 = time.time() - # Transform the prediction(body frame) to the world frame with the attitude in inference - # Replacing PyTorch calculations on CUDA with NumPy calculations on the CPU (speed increased by 10x) - endstate_b = endstate_pred - endstate_w = np.zeros_like(endstate_b) - traj_num = self.lattice_space.horizon_num * self.lattice_space.vertical_num if self.visualize else 1 - Pb, Vb, Ab = [np.zeros((3, traj_num)) for _ in range(3)] - for i in range(3): - Pb[i] = endstate_b[3 * i] - Vb[i] = endstate_b[3 * i + 1] - Ab[i] = endstate_b[3 * i + 2] - # pos_actual = np.array([self.odom.pose.pose.position.x, - # self.odom.pose.pose.position.y, - # self.odom.pose.pose.position.z]) - Pw = np.dot(self.Rotation_wc, Pb) # + np.tile(pos_actual, (15, 1)).T - Vw = np.dot(self.Rotation_wc, Vb) - Aw = np.dot(self.Rotation_wc, Ab) - for i in range(3): - endstate_w[3 * i] = Pw[i] - endstate_w[3 * i + 1] = Vw[i] - endstate_w[3 * i + 2] = Aw[i] + # Vectorization: transform the prediction(P V A in body frame) to the world frame with the attitude (without the position) + endstate_c = endstate_pred.T.reshape(-1, 3, 3) + endstate_w = np.matmul(self.Rotation_wc, endstate_c) + endstate_w = endstate_w.reshape(-1, 9).T if self.verbose: self.time_prepare = self.time_prepare + (time1 - time0) self.time_forward = self.time_forward + (time2 - time1) self.time_process = self.time_process + (time3 - time2) self.count = self.count + 1 - print("Time Consuming: prepare:", self.time_prepare / self.count, "; forward:", self.time_forward / self.count, - "; process:", self.time_process / self.count) + print(f"Time Consuming: data-prepare: {1000 * self.time_prepare / self.count:.2f}ms; " + f"network-inference: {1000 * self.time_forward / self.count:.2f}ms; " + f"post-process: {1000 * self.time_process / self.count:.2f}ms") # publish if not self.visualize: @@ -232,7 +211,7 @@ class YopoNet: endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms) self.endstate_pub.publish(endstate_pred_to_pub) # visualization - endstate_score_preds = np.concatenate((endstate_w, score_pred), axis=0) + endstate_score_preds = np.vstack([endstate_w, score_pred]) all_endstate_pred = Float32MultiArray(data=endstate_score_preds.T.reshape(-1)) all_endstate_pred.layout.dim.append(MultiArrayDimension()) all_endstate_pred.layout.dim[0].size = endstate_score_preds.shape[1] @@ -246,28 +225,25 @@ class YopoNet: elif not self.new_odom: self.odom_ref_init = False - def trt_process(self, input_tensor: torch.Tensor, return_all_preds=False) -> torch.Tensor: - batch_size = input_tensor.shape[0] - input_tensor = input_tensor.cpu().numpy() - input_tensor = input_tensor.reshape(batch_size, 10, self.lattice_space.horizon_num * self.lattice_space.vertical_num) - endstate_pred = input_tensor[:, 0:9, :] - score_pred = input_tensor[:, 9, :] + def process_output(self, network_output, return_all_preds=False): + if network_output.shape[0] != 1: + raise ValueError("batch of output values must be 1 in test!") + network_output = network_output.reshape(10, self.lattice_space.horizon_num * self.lattice_space.vertical_num) + endstate_pred = network_output[0:9, :] + score_pred = network_output[9, :] if not return_all_preds: - endstate_prediction = np.zeros((batch_size, 9)) - score_prediction = np.zeros((batch_size, 1)) - for i in range(0, batch_size): - action_id = np.argmin(score_pred[i]) - lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id - endstate_prediction[i] = self.pred_to_endstate(np.expand_dims(endstate_pred[i, :, action_id], axis=0), lattice_id) - score_prediction[i] = score_pred[i, action_id] + action_id = np.argmin(score_pred) + lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id + endstate_prediction = self.pred_to_endstate(endstate_pred[:, action_id], lattice_id) + endstate_prediction = endstate_prediction[:, np.newaxis] + score_prediction = score_pred[action_id] else: endstate_prediction = np.zeros_like(endstate_pred) score_prediction = score_pred - for i in range(0, self.lattice_space.horizon_num * self.lattice_space.vertical_num): + for i in range(self.lattice_space.horizon_num * self.lattice_space.vertical_num): lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i - endstate = self.pred_to_endstate(endstate_pred[:, :, i], lattice_id) - endstate_prediction[:, :, i] = endstate + endstate_prediction[:, i] = self.pred_to_endstate(endstate_pred[:, i], lattice_id) return endstate_prediction, score_prediction @@ -276,45 +252,40 @@ class YopoNet: convert the observation from body frame to primitive frame, and then concatenate it with the depth features (to ensure the translational invariance) """ - obs_return = np.ones((obs.shape[0], self.lattice_space.vertical_num, self.lattice_space.horizon_num, obs.shape[1]), dtype=np.float32) + if obs.shape[0] != 1: + raise ValueError("batch of input observations must be 1 in test!") + + obs_return = np.ones((obs.shape[0], obs.shape[1], self.lattice_space.vertical_num, self.lattice_space.horizon_num), dtype=np.float32) id = 0 - v_b = obs[:, 0:3] - a_b = obs[:, 3:6] - g_b = obs[:, 6:9] + obs_reshaped = obs.reshape(3, 3) for i in range(self.lattice_space.vertical_num - 1, -1, -1): for j in range(self.lattice_space.horizon_num - 1, -1, -1): Rbp = self.lattice_primitive.getRotation(id) - v_p = np.dot(Rbp.T, v_b.T).T - a_p = np.dot(Rbp.T, a_b.T).T - g_p = np.dot(Rbp.T, g_b.T).T - obs_return[:, i, j, 0:3] = v_p - obs_return[:, i, j, 3:6] = a_p - obs_return[:, i, j, 6:9] = g_p - # obs_return[:, i, j, 0:6] = self.normalize_obs(obs_return[:, i, j, 0:6]) + obs_return_reshaped = np.dot(obs_reshaped, Rbp) + obs_return[:, :, i, j] = obs_return_reshaped.reshape(9) id = id + 1 - obs_return = np.transpose(obs_return, [0, 3, 1, 2]) return torch.from_numpy(obs_return) def pred_to_endstate(self, endstate_pred: np.ndarray, id: int): """ Transform the predicted state to the body frame. """ - delta_yaw = endstate_pred[:, 0] * self.lattice_primitive.yaw_diff - delta_pitch = endstate_pred[:, 1] * self.lattice_primitive.pitch_diff - radio = endstate_pred[:, 2] * self.lattice_space.radio_range + self.lattice_space.radio_range + delta_yaw = endstate_pred[0] * self.lattice_primitive.yaw_diff + delta_pitch = endstate_pred[1] * self.lattice_primitive.pitch_diff + radio = endstate_pred[2] * self.lattice_space.radio_range + self.lattice_space.radio_range yaw, pitch = self.lattice_primitive.getAngleLattice(id) endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio endstate_z = np.sin(pitch + delta_pitch) * radio - endstate_p = np.stack((endstate_x, endstate_y, endstate_z), axis=1) + endstate_p = np.array((endstate_x, endstate_y, endstate_z)) - endstate_vp = endstate_pred[:, 3:6] * self.lattice_space.vel_max - endstate_ap = endstate_pred[:, 6:9] * self.lattice_space.acc_max - Rbp = self.lattice_primitive.getRotation(id) - endstate_vb = np.matmul(Rbp, endstate_vp.T).T - endstate_ab = np.matmul(Rbp, endstate_ap.T).T - endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab), axis=1) - endstate[:, [0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[:, [0, 3, 6, 1, 4, 7, 2, 5, 8]] + endstate_vp = endstate_pred[3:6] * self.lattice_space.vel_max + endstate_ap = endstate_pred[6:9] * self.lattice_space.acc_max + Rpb = self.lattice_primitive.getRotation(id).T + endstate_vb = np.matmul(endstate_vp, Rpb) + endstate_ab = np.matmul(endstate_ap, Rpb) + endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab)) + endstate[[0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[[0, 3, 6, 1, 4, 7, 2, 5, 8]] return endstate def normalize_obs(self, vel_acc): @@ -326,11 +297,8 @@ class YopoNet: depth = np.zeros(shape=[1, 1, self.height, self.width], dtype=np.float32) obs = np.zeros(shape=[1, 9], dtype=np.float32) obs_input = self.prepare_input_observation(obs) - if self.use_trt: - trt_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device)) - self.trt_process(trt_output, return_all_preds=True) - else: - self.policy.predict(torch.from_numpy(depth).to(self.device), obs_input.to(self.device), return_all_preds=True) + network_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device)) + self.process_output(network_output.cpu().numpy(), return_all_preds=True) def parser(): @@ -339,6 +307,7 @@ def parser(): parser.add_argument("--trial", type=int, default=1, help="trial number") parser.add_argument("--epoch", type=int, default=0, help="epoch number") parser.add_argument("--iter", type=int, default=0, help="iter number") + parser.add_argument("--trt_file", type=str, default='yopo_trt.pth', help="tensorrt filename") return parser @@ -348,9 +317,10 @@ def main(): args = parser().parse_args() rsg_root = os.path.dirname(os.path.abspath(__file__)) if args.use_tensorrt: - weight = "yopo_trt.pth" + weight = args.trt_file else: weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter) + print("load weight from:", weight) settings = {'use_tensorrt': args.use_tensorrt, 'network_frequency': 30, diff --git a/run/yopo_trt_transfer.py b/run/yopo_trt_transfer.py index 1f464ad..e86af55 100644 --- a/run/yopo_trt_transfer.py +++ b/run/yopo_trt_transfer.py @@ -46,6 +46,7 @@ def parser(): parser.add_argument("--trial", type=int, default=1, help="trial number") parser.add_argument("--epoch", type=int, default=0, help="epoch number") parser.add_argument("--iter", type=int, default=0, help="iter number") + parser.add_argument("--fp16_mode", type=int, default=1, help="fp16 or fp32") parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name") return parser @@ -75,6 +76,7 @@ if __name__ == "__main__": saved_variables = torch.load(weight, map_location=device) model.policy.load_state_dict(saved_variables["state_dict"], strict=False) model.policy.set_training_mode(False) + torch.set_grad_enabled(False) lattice_space = saved_variables["data"]["lattice_space"] lattice_primitive = saved_variables["data"]["lattice_primitive"] @@ -86,7 +88,7 @@ if __name__ == "__main__": obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive) depth_in = torch.from_numpy(depth).cuda() obs_in = torch.from_numpy(obs_input).cuda() - model_trt = torch2trt(model.policy, [depth_in, obs_in]) + model_trt = torch2trt(model.policy, [depth_in, obs_in], fp16_mode=args.fp16_mode) torch.save(model_trt.state_dict(), args.filename) print("TensorRT Transfer Finish!") @@ -95,18 +97,27 @@ if __name__ == "__main__": # model_trt.load_state_dict(torch.load('yopo_trt.pth')) print("Evaluation...") - # warm up... + # Warm Up... y_trt = model_trt(depth_in, obs_in) y = model.policy(depth_in, obs_in) + torch.cuda.synchronize() + # PyTorch Latency torch_start = time.time() y = model.policy(depth_in, obs_in) + torch.cuda.synchronize() torch_end = time.time() + + # TensorRT Latency + trt_start = time.time() y_trt = model_trt(depth_in, obs_in) + torch.cuda.synchronize() trt_end = time.time() + # Transfer Error error = torch.mean(torch.abs(y - y_trt)) - print("Torch Latency: ", 1000 * (torch_end - torch_start), - "ms, TensorRT Latency: ", 1000 * (trt_end - torch_end), - "ms, Transfer Error: ", error.item()) + + print(f"Torch Latency: {1000 * (torch_end - torch_start):.3f} ms, " + f"TensorRT Latency: {1000 * (trt_end - trt_start):.3f} ms, " + f"Transfer Error: {error.item():.8f}")