move network inference to depth callback
This commit is contained in:
parent
35cd195a10
commit
e3f7af7f01
@ -82,7 +82,7 @@ class YopoNet:
|
|||||||
self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref, queue_size=1, tcp_nodelay=True)
|
self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref, queue_size=1, tcp_nodelay=True)
|
||||||
self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True)
|
self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True)
|
||||||
self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1)
|
self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1)
|
||||||
self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy)
|
# self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy)
|
||||||
print("YOPO Net Node Ready!")
|
print("YOPO Net Node Ready!")
|
||||||
rospy.spin()
|
rospy.spin()
|
||||||
|
|
||||||
@ -134,6 +134,13 @@ class YopoNet:
|
|||||||
return obs_norm
|
return obs_norm
|
||||||
|
|
||||||
def callback_depth(self, data):
|
def callback_depth(self, data):
|
||||||
|
# start a new round if no new odom_ref (we will stop publish odom_ref when arrive and assume that the rate of odom is higher than depth)
|
||||||
|
if not self.new_odom:
|
||||||
|
self.odom_ref_init = False
|
||||||
|
return
|
||||||
|
self.new_odom = False
|
||||||
|
|
||||||
|
# 1. Depth Image Process
|
||||||
min_dis, max_dis = 0.03, 20.0
|
min_dis, max_dis = 0.03, 20.0
|
||||||
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
|
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
|
||||||
|
|
||||||
@ -147,24 +154,20 @@ class YopoNet:
|
|||||||
depth_ = np.minimum(depth_ * scale, max_dis) / max_dis
|
depth_ = np.minimum(depth_ * scale, max_dis) / max_dis
|
||||||
|
|
||||||
# interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results)
|
# interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results)
|
||||||
start = time.time()
|
interpolation_start = time.time()
|
||||||
nan_mask = np.isnan(depth_) | (depth_ < min_dis)
|
nan_mask = np.isnan(depth_) | (depth_ < min_dis)
|
||||||
interpolated_image = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS)
|
interpolated_image = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS)
|
||||||
interpolated_image = interpolated_image.astype(np.float32) / 255.0
|
interpolated_image = interpolated_image.astype(np.float32) / 255.0
|
||||||
depth_ = interpolated_image.reshape([1, 1, self.height, self.width])
|
depth_ = interpolated_image.reshape([1, 1, self.height, self.width])
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.time_interpolation = self.time_interpolation + (time.time() - start)
|
self.time_interpolation = self.time_interpolation + (time.time() - interpolation_start)
|
||||||
self.count_interpolation = self.count_interpolation + 1
|
self.count_interpolation = self.count_interpolation + 1
|
||||||
print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms")
|
print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms")
|
||||||
# cv2.imshow("1", depth_[0][0])
|
# cv2.imshow("1", depth_[0][0])
|
||||||
# cv2.waitKey(1)
|
# cv2.waitKey(1)
|
||||||
self.depth = depth_.astype(np.float32)
|
self.depth = depth_.astype(np.float32)
|
||||||
self.new_depth = True
|
|
||||||
|
|
||||||
# TODO: Move the test_policy to callback_depth directly?
|
# 2. YOPO Network Inference
|
||||||
def test_policy(self, _timer):
|
|
||||||
if self.new_depth and self.new_odom:
|
|
||||||
self.new_odom, self.new_depth = False, False
|
|
||||||
obs = self.process_odom()
|
obs = self.process_odom()
|
||||||
odom_sec = self.odom.header.stamp.to_sec()
|
odom_sec = self.odom.header.stamp.to_sec()
|
||||||
|
|
||||||
@ -221,9 +224,6 @@ class YopoNet:
|
|||||||
all_endstate_pred.layout.dim[1].label = "endstate_and_score_num"
|
all_endstate_pred.layout.dim[1].label = "endstate_and_score_num"
|
||||||
self.all_endstate_pub.publish(all_endstate_pred)
|
self.all_endstate_pub.publish(all_endstate_pred)
|
||||||
self.goal_pub.publish(Float32MultiArray(data=self.goal))
|
self.goal_pub.publish(Float32MultiArray(data=self.goal))
|
||||||
# start a new round
|
|
||||||
elif not self.new_odom:
|
|
||||||
self.odom_ref_init = False
|
|
||||||
|
|
||||||
def process_output(self, network_output, return_all_preds=False):
|
def process_output(self, network_output, return_all_preds=False):
|
||||||
if network_output.shape[0] != 1:
|
if network_output.shape[0] != 1:
|
||||||
@ -313,17 +313,14 @@ def parser():
|
|||||||
|
|
||||||
# In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform
|
# In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform
|
||||||
# When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system
|
# When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system
|
||||||
|
# Change the flight speed at traj_opt.yaml and there's no need to re-collect and re-train
|
||||||
def main():
|
def main():
|
||||||
args = parser().parse_args()
|
args = parser().parse_args()
|
||||||
rsg_root = os.path.dirname(os.path.abspath(__file__))
|
rsg_root = os.path.dirname(os.path.abspath(__file__))
|
||||||
if args.use_tensorrt:
|
weight = args.trt_file if args.use_tensorrt else f"{rsg_root}/saved/YOPO_{args.trial}/Policy/epoch{args.epoch}_iter{args.iter}.pth"
|
||||||
weight = args.trt_file
|
|
||||||
else:
|
|
||||||
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
|
|
||||||
print("load weight from:", weight)
|
print("load weight from:", weight)
|
||||||
|
|
||||||
settings = {'use_tensorrt': args.use_tensorrt,
|
settings = {'use_tensorrt': args.use_tensorrt,
|
||||||
'network_frequency': 30,
|
|
||||||
'img_height': 96,
|
'img_height': 96,
|
||||||
'img_width': 160,
|
'img_width': 160,
|
||||||
'goal': [20, 20, 2], # the goal
|
'goal': [20, 20, 2], # the goal
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user