move network inference to depth callback

This commit is contained in:
TJU_Lu 2024-12-17 22:03:45 +08:00
parent 35cd195a10
commit e3f7af7f01

View File

@ -82,7 +82,7 @@ class YopoNet:
self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref, queue_size=1, tcp_nodelay=True) self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref, queue_size=1, tcp_nodelay=True)
self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True) self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True)
self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1) self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1)
self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy) # self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy)
print("YOPO Net Node Ready!") print("YOPO Net Node Ready!")
rospy.spin() rospy.spin()
@ -134,6 +134,13 @@ class YopoNet:
return obs_norm return obs_norm
def callback_depth(self, data): def callback_depth(self, data):
# start a new round if no new odom_ref (we will stop publish odom_ref when arrive and assume that the rate of odom is higher than depth)
if not self.new_odom:
self.odom_ref_init = False
return
self.new_odom = False
# 1. Depth Image Process
min_dis, max_dis = 0.03, 20.0 min_dis, max_dis = 0.03, 20.0
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0) scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
@ -147,24 +154,20 @@ class YopoNet:
depth_ = np.minimum(depth_ * scale, max_dis) / max_dis depth_ = np.minimum(depth_ * scale, max_dis) / max_dis
# interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results) # interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results)
start = time.time() interpolation_start = time.time()
nan_mask = np.isnan(depth_) | (depth_ < min_dis) nan_mask = np.isnan(depth_) | (depth_ < min_dis)
interpolated_image = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS) interpolated_image = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS)
interpolated_image = interpolated_image.astype(np.float32) / 255.0 interpolated_image = interpolated_image.astype(np.float32) / 255.0
depth_ = interpolated_image.reshape([1, 1, self.height, self.width]) depth_ = interpolated_image.reshape([1, 1, self.height, self.width])
if self.verbose: if self.verbose:
self.time_interpolation = self.time_interpolation + (time.time() - start) self.time_interpolation = self.time_interpolation + (time.time() - interpolation_start)
self.count_interpolation = self.count_interpolation + 1 self.count_interpolation = self.count_interpolation + 1
print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms") print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms")
# cv2.imshow("1", depth_[0][0]) # cv2.imshow("1", depth_[0][0])
# cv2.waitKey(1) # cv2.waitKey(1)
self.depth = depth_.astype(np.float32) self.depth = depth_.astype(np.float32)
self.new_depth = True
# TODO: Move the test_policy to callback_depth directly? # 2. YOPO Network Inference
def test_policy(self, _timer):
if self.new_depth and self.new_odom:
self.new_odom, self.new_depth = False, False
obs = self.process_odom() obs = self.process_odom()
odom_sec = self.odom.header.stamp.to_sec() odom_sec = self.odom.header.stamp.to_sec()
@ -221,9 +224,6 @@ class YopoNet:
all_endstate_pred.layout.dim[1].label = "endstate_and_score_num" all_endstate_pred.layout.dim[1].label = "endstate_and_score_num"
self.all_endstate_pub.publish(all_endstate_pred) self.all_endstate_pub.publish(all_endstate_pred)
self.goal_pub.publish(Float32MultiArray(data=self.goal)) self.goal_pub.publish(Float32MultiArray(data=self.goal))
# start a new round
elif not self.new_odom:
self.odom_ref_init = False
def process_output(self, network_output, return_all_preds=False): def process_output(self, network_output, return_all_preds=False):
if network_output.shape[0] != 1: if network_output.shape[0] != 1:
@ -313,17 +313,14 @@ def parser():
# In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform # In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform
# When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system # When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system
# Change the flight speed at traj_opt.yaml and there's no need to re-collect and re-train
def main(): def main():
args = parser().parse_args() args = parser().parse_args()
rsg_root = os.path.dirname(os.path.abspath(__file__)) rsg_root = os.path.dirname(os.path.abspath(__file__))
if args.use_tensorrt: weight = args.trt_file if args.use_tensorrt else f"{rsg_root}/saved/YOPO_{args.trial}/Policy/epoch{args.epoch}_iter{args.iter}.pth"
weight = args.trt_file
else:
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
print("load weight from:", weight) print("load weight from:", weight)
settings = {'use_tensorrt': args.use_tensorrt, settings = {'use_tensorrt': args.use_tensorrt,
'network_frequency': 30,
'img_height': 96, 'img_height': 96,
'img_width': 160, 'img_width': 160,
'goal': [20, 20, 2], # the goal 'goal': [20, 20, 2], # the goal