move network inference to depth callback

This commit is contained in:
TJU_Lu 2024-12-17 22:03:45 +08:00
parent 35cd195a10
commit e3f7af7f01

View File

@ -82,7 +82,7 @@ class YopoNet:
self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref, queue_size=1, tcp_nodelay=True) self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref, queue_size=1, tcp_nodelay=True)
self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True) self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True)
self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1) self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1)
self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy) # self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy)
print("YOPO Net Node Ready!") print("YOPO Net Node Ready!")
rospy.spin() rospy.spin()
@ -134,6 +134,13 @@ class YopoNet:
return obs_norm return obs_norm
def callback_depth(self, data): def callback_depth(self, data):
# start a new round if no new odom_ref (we will stop publish odom_ref when arrive and assume that the rate of odom is higher than depth)
if not self.new_odom:
self.odom_ref_init = False
return
self.new_odom = False
# 1. Depth Image Process
min_dis, max_dis = 0.03, 20.0 min_dis, max_dis = 0.03, 20.0
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0) scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
@ -147,83 +154,76 @@ class YopoNet:
depth_ = np.minimum(depth_ * scale, max_dis) / max_dis depth_ = np.minimum(depth_ * scale, max_dis) / max_dis
# interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results) # interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results)
start = time.time() interpolation_start = time.time()
nan_mask = np.isnan(depth_) | (depth_ < min_dis) nan_mask = np.isnan(depth_) | (depth_ < min_dis)
interpolated_image = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS) interpolated_image = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS)
interpolated_image = interpolated_image.astype(np.float32) / 255.0 interpolated_image = interpolated_image.astype(np.float32) / 255.0
depth_ = interpolated_image.reshape([1, 1, self.height, self.width]) depth_ = interpolated_image.reshape([1, 1, self.height, self.width])
if self.verbose: if self.verbose:
self.time_interpolation = self.time_interpolation + (time.time() - start) self.time_interpolation = self.time_interpolation + (time.time() - interpolation_start)
self.count_interpolation = self.count_interpolation + 1 self.count_interpolation = self.count_interpolation + 1
print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms") print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms")
# cv2.imshow("1", depth_[0][0]) # cv2.imshow("1", depth_[0][0])
# cv2.waitKey(1) # cv2.waitKey(1)
self.depth = depth_.astype(np.float32) self.depth = depth_.astype(np.float32)
self.new_depth = True
# TODO: Move the test_policy to callback_depth directly? # 2. YOPO Network Inference
def test_policy(self, _timer): obs = self.process_odom()
if self.new_depth and self.new_odom: odom_sec = self.odom.header.stamp.to_sec()
self.new_odom, self.new_depth = False, False
obs = self.process_odom()
odom_sec = self.odom.header.stamp.to_sec()
# input prepare # input prepare
time0 = time.time() time0 = time.time()
depth = torch.from_numpy(self.depth).to(self.device, non_blocking=True) # (non_blocking: copying speed 3x) depth = torch.from_numpy(self.depth).to(self.device, non_blocking=True) # (non_blocking: copying speed 3x)
obs_norm_input = self.prepare_input_observation(obs) obs_norm_input = self.prepare_input_observation(obs)
obs_norm_input = obs_norm_input.to(self.device, non_blocking=True) obs_norm_input = obs_norm_input.to(self.device, non_blocking=True)
# torch.cuda.synchronize() # torch.cuda.synchronize()
time1 = time.time() time1 = time.time()
# Forward (TensorRT: inference speed increased by 5x) # Forward (TensorRT: inference speed increased by 5x)
with torch.no_grad(): with torch.no_grad():
network_output = self.policy(depth, obs_norm_input) network_output = self.policy(depth, obs_norm_input)
network_output = network_output.cpu().numpy() # torch.cuda.synchronize() is not needed here network_output = network_output.cpu().numpy() # torch.cuda.synchronize() is not needed here
time2 = time.time() time2 = time.time()
# Replacing PyTorch operation on CUDA with NumPy operation on CPU (speed increased by 10x) # Replacing PyTorch operation on CUDA with NumPy operation on CPU (speed increased by 10x)
endstate_pred, score_pred = self.process_output(network_output, return_all_preds=self.visualize) endstate_pred, score_pred = self.process_output(network_output, return_all_preds=self.visualize)
time3 = time.time() time3 = time.time()
# Vectorization: transform the prediction(P V A in body frame) to the world frame with the attitude (without the position) # Vectorization: transform the prediction(P V A in body frame) to the world frame with the attitude (without the position)
endstate_c = endstate_pred.T.reshape(-1, 3, 3) endstate_c = endstate_pred.T.reshape(-1, 3, 3)
endstate_w = np.matmul(self.Rotation_wc, endstate_c) endstate_w = np.matmul(self.Rotation_wc, endstate_c)
endstate_w = endstate_w.reshape(-1, 9).T endstate_w = endstate_w.reshape(-1, 9).T
if self.verbose: if self.verbose:
self.time_prepare = self.time_prepare + (time1 - time0) self.time_prepare = self.time_prepare + (time1 - time0)
self.time_forward = self.time_forward + (time2 - time1) self.time_forward = self.time_forward + (time2 - time1)
self.time_process = self.time_process + (time3 - time2) self.time_process = self.time_process + (time3 - time2)
self.count = self.count + 1 self.count = self.count + 1
print(f"Time Consuming: data-prepare: {1000 * self.time_prepare / self.count:.2f}ms; " print(f"Time Consuming: data-prepare: {1000 * self.time_prepare / self.count:.2f}ms; "
f"network-inference: {1000 * self.time_forward / self.count:.2f}ms; " f"network-inference: {1000 * self.time_forward / self.count:.2f}ms; "
f"post-process: {1000 * self.time_process / self.count:.2f}ms") f"post-process: {1000 * self.time_process / self.count:.2f}ms")
# publish # publish
if not self.visualize: if not self.visualize:
endstate_pred_to_pub = Float32MultiArray(data=endstate_w.reshape(-1)) endstate_pred_to_pub = Float32MultiArray(data=endstate_w.reshape(-1))
endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms) endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms)
self.endstate_pub.publish(endstate_pred_to_pub) self.endstate_pub.publish(endstate_pred_to_pub)
else: else:
action_id = np.argmin(score_pred) action_id = np.argmin(score_pred)
best_endstate_pred = endstate_w[:, action_id].reshape(-1) best_endstate_pred = endstate_w[:, action_id].reshape(-1)
endstate_pred_to_pub = Float32MultiArray(data=best_endstate_pred) endstate_pred_to_pub = Float32MultiArray(data=best_endstate_pred)
endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms) endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms)
self.endstate_pub.publish(endstate_pred_to_pub) self.endstate_pub.publish(endstate_pred_to_pub)
# visualization # visualization
endstate_score_preds = np.vstack([endstate_w, score_pred]) endstate_score_preds = np.vstack([endstate_w, score_pred])
all_endstate_pred = Float32MultiArray(data=endstate_score_preds.T.reshape(-1)) all_endstate_pred = Float32MultiArray(data=endstate_score_preds.T.reshape(-1))
all_endstate_pred.layout.dim.append(MultiArrayDimension()) all_endstate_pred.layout.dim.append(MultiArrayDimension())
all_endstate_pred.layout.dim[0].size = endstate_score_preds.shape[1] all_endstate_pred.layout.dim[0].size = endstate_score_preds.shape[1]
all_endstate_pred.layout.dim[0].label = "primitive_num" all_endstate_pred.layout.dim[0].label = "primitive_num"
all_endstate_pred.layout.dim.append(MultiArrayDimension()) all_endstate_pred.layout.dim.append(MultiArrayDimension())
all_endstate_pred.layout.dim[1].size = endstate_score_preds.shape[0] all_endstate_pred.layout.dim[1].size = endstate_score_preds.shape[0]
all_endstate_pred.layout.dim[1].label = "endstate_and_score_num" all_endstate_pred.layout.dim[1].label = "endstate_and_score_num"
self.all_endstate_pub.publish(all_endstate_pred) self.all_endstate_pub.publish(all_endstate_pred)
self.goal_pub.publish(Float32MultiArray(data=self.goal)) self.goal_pub.publish(Float32MultiArray(data=self.goal))
# start a new round
elif not self.new_odom:
self.odom_ref_init = False
def process_output(self, network_output, return_all_preds=False): def process_output(self, network_output, return_all_preds=False):
if network_output.shape[0] != 1: if network_output.shape[0] != 1:
@ -313,17 +313,14 @@ def parser():
# In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform # In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform
# When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system # When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system
# Change the flight speed at traj_opt.yaml and there's no need to re-collect and re-train
def main(): def main():
args = parser().parse_args() args = parser().parse_args()
rsg_root = os.path.dirname(os.path.abspath(__file__)) rsg_root = os.path.dirname(os.path.abspath(__file__))
if args.use_tensorrt: weight = args.trt_file if args.use_tensorrt else f"{rsg_root}/saved/YOPO_{args.trial}/Policy/epoch{args.epoch}_iter{args.iter}.pth"
weight = args.trt_file
else:
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
print("load weight from:", weight) print("load weight from:", weight)
settings = {'use_tensorrt': args.use_tensorrt, settings = {'use_tensorrt': args.use_tensorrt,
'network_frequency': 30,
'img_height': 96, 'img_height': 96,
'img_width': 160, 'img_width': 160,
'goal': [20, 20, 2], # the goal 'goal': [20, 20, 2], # the goal