Fix an issue when testing multi-round in simulation
This commit is contained in:
parent
361e57fc2b
commit
c815ae416e
@ -10,9 +10,9 @@ Some realworld experiment: [YouTube](https://youtu.be/LHvtbKmTwvE), [bilibili](h
|
||||
## Introduction:
|
||||
We propose **a learning-based planner for autonomous navigation in obstacle-dense environments** which integrates (i) perception and mapping, (ii) front-end path searching, and (iii) back-end optimization of classical methods into a single network.
|
||||
|
||||
Considering the multi-modal nature of the navigation problem and to avoid local minima around initial values, our approach adopts a set of motion primitives as anchors to cover the searching space, and predicts the offsets and scores of primitives for further improvement (like the one-stage object detector YOLO).
|
||||
**Learning-based Planner:** Considering the multi-modal nature of the navigation problem and to avoid local minima around initial values, our approach adopts a set of motion primitives as anchors to cover the searching space, and predicts the offsets and scores of primitives for further improvement (like the one-stage object detector YOLO).
|
||||
|
||||
Compared to giving expert demonstrations for imitation in imitation learning or exploring by trial-and-error in reinforcement learning, we directly back-propagate the numerical gradient (e.g. from ESDF) to the weights of neural network in the training process, which is realistic, accurate, and timely.
|
||||
**Training Strategy:** Compared to giving expert demonstrations for imitation in imitation learning or exploring by trial-and-error in reinforcement learning, we directly back-propagate the numerical gradient (e.g. from ESDF) to the weights of neural network in the training process, which is realistic, accurate, and timely.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
#include <ros/ros.h>
|
||||
#include <std_msgs/Float32MultiArray.h>
|
||||
#include <yaml-cpp/yaml.h>
|
||||
|
||||
#include <Eigen/Core>
|
||||
|
||||
#include "flightlib/controller/PositionCommand.h"
|
||||
@ -20,7 +19,7 @@ quadrotor_msgs::PositionCommand pos_cmd_last;
|
||||
bool odom_init = false;
|
||||
bool odom_ref_init = false;
|
||||
bool yopo_init = false;
|
||||
bool done = false;
|
||||
bool arrive = false;
|
||||
float traj_time = 2.0;
|
||||
float sample_t = 0.0;
|
||||
float last_yaw_ = 0; // NWU
|
||||
@ -33,7 +32,7 @@ Eigen::Vector3d goal_(100, 0, 2);
|
||||
Eigen::Quaterniond quat_(1, 0, 0, 0);
|
||||
Eigen::Vector3d last_pos_(0, 0, 0), last_vel_(0, 0, 0), last_acc_(0, 0, 0);
|
||||
|
||||
ros::Publisher trajs_visual_pub, best_traj_visual_pub, state_ref_pub, ctrl_pub, mpc_ctrl_pub, so3_ctrl_pub, lattice_trajs_visual_pub;
|
||||
ros::Publisher trajs_visual_pub, best_traj_visual_pub, state_ref_pub, our_ctrl_pub, so3_ctrl_pub, lattice_trajs_visual_pub;
|
||||
ros::Subscriber odom_sub, odom_ref_sub, yopo_best_sub, yopo_all_sub, goal_sub;
|
||||
|
||||
void odom_cb(const nav_msgs::Odometry::Ptr msg) {
|
||||
@ -52,10 +51,9 @@ void odom_cb(const nav_msgs::Odometry::Ptr msg) {
|
||||
// check if reach the goal
|
||||
Eigen::Vector3d dist(odom_msg.pose.pose.position.x - goal_(0), odom_msg.pose.pose.position.y - goal_(1),
|
||||
odom_msg.pose.pose.position.z - goal_(2));
|
||||
if (dist.norm() < 4) {
|
||||
if (!done)
|
||||
printf("Done!\n");
|
||||
done = true;
|
||||
if (dist.norm() < 4 && !arrive) {
|
||||
printf("Arrive!\n");
|
||||
arrive = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -65,7 +63,7 @@ void goal_cb(const std_msgs::Float32MultiArray::Ptr msg) {
|
||||
goal_(1) = msg->data[1];
|
||||
goal_(2) = msg->data[2];
|
||||
if (last_goal != goal_)
|
||||
done = false;
|
||||
arrive = false;
|
||||
}
|
||||
|
||||
void traj_to_pcl(TrajOptimizationBridge* traj_opt_bridge_input, pcl::PointCloud<pcl::PointXYZI>::Ptr cloud, double cost = 0.0) {
|
||||
@ -225,23 +223,23 @@ void ref_pub_cb(const ros::TimerEvent&) {
|
||||
if (!yopo_init)
|
||||
return;
|
||||
|
||||
if (done) {
|
||||
if (arrive) {
|
||||
odom_ref_init = false;
|
||||
// single state control for smoother performance
|
||||
ctrl_ref_last.header.stamp = ros::Time::now();
|
||||
ctrl_ref_last.vel_ref = {0, 0, 0};
|
||||
ctrl_ref_last.acc_ref = {0, 0, 0};
|
||||
ctrl_ref_last.ref_mask = 1;
|
||||
ctrl_pub.publish(ctrl_ref_last);
|
||||
|
||||
// un-smooth, just for simpler demonstration
|
||||
pos_cmd_last.header.stamp = ros::Time::now();
|
||||
pos_cmd_last.velocity.x = 0.95 * pos_cmd_last.velocity.x;
|
||||
pos_cmd_last.velocity.y = 0.95 * pos_cmd_last.velocity.y;
|
||||
pos_cmd_last.velocity.z = 0.95 * pos_cmd_last.velocity.z;
|
||||
pos_cmd_last.acceleration.x = 0.95 * pos_cmd_last.acceleration.x;
|
||||
pos_cmd_last.acceleration.y = 0.95 * pos_cmd_last.acceleration.y;
|
||||
pos_cmd_last.acceleration.z = 0.95 * pos_cmd_last.acceleration.z;
|
||||
pos_cmd_last.yaw_dot = 0.95 * pos_cmd_last.yaw_dot;
|
||||
our_ctrl_pub.publish(ctrl_ref_last);
|
||||
// larger position error, just for simpler demonstration
|
||||
pos_cmd_last.header.stamp = ros::Time::now();
|
||||
// pos_cmd_last.velocity.x = 0.0;
|
||||
// pos_cmd_last.velocity.y = 0.0;
|
||||
// pos_cmd_last.velocity.z = 0.0;
|
||||
// pos_cmd_last.acceleration.x = 0.0;
|
||||
// pos_cmd_last.acceleration.y = 0.0;
|
||||
// pos_cmd_last.acceleration.z = 0.0;
|
||||
// pos_cmd_last.yaw_dot = 0.0;
|
||||
so3_ctrl_pub.publish(pos_cmd_last);
|
||||
return;
|
||||
}
|
||||
@ -261,7 +259,7 @@ void ref_pub_cb(const ros::TimerEvent&) {
|
||||
ctrl_msg.yaw_ref = -yaw_yawdot.first;
|
||||
ctrl_msg.ref_mask = 7;
|
||||
ctrl_ref_last = ctrl_msg;
|
||||
ctrl_pub.publish(ctrl_msg);
|
||||
our_ctrl_pub.publish(ctrl_msg);
|
||||
|
||||
// SO3 Simulator & SO3 Controller
|
||||
quadrotor_msgs::PositionCommand cmd;
|
||||
@ -333,7 +331,7 @@ int main(int argc, char** argv) {
|
||||
state_ref_pub = nh.advertise<nav_msgs::Odometry>("/juliett/state_ref/odom", 10);
|
||||
|
||||
// our PID Controller (realworld) & SO3 Controller (simulation)
|
||||
ctrl_pub = nh.advertise<quad_pos_ctrl::ctrl_ref>("/juliett_pos_ctrl_node/controller/ctrl_ref", 10);
|
||||
our_ctrl_pub = nh.advertise<quad_pos_ctrl::ctrl_ref>("/juliett_pos_ctrl_node/controller/ctrl_ref", 10);
|
||||
so3_ctrl_pub = nh.advertise<quadrotor_msgs::PositionCommand>("/so3_control/pos_cmd", 10);
|
||||
|
||||
odom_sub = nh.subscribe("/juliett/ground_truth/odom", 1, yopo_net::odom_cb, ros::TransportHints().tcpNoDelay());
|
||||
|
||||
@ -79,8 +79,7 @@ class YopoNet:
|
||||
self.goal_pub = rospy.Publisher("/yopo_net/goal", Float32MultiArray, queue_size=1)
|
||||
# ros subscriber
|
||||
self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.callback_odometry, queue_size=1, tcp_nodelay=True)
|
||||
self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref,
|
||||
queue_size=1, tcp_nodelay=True)
|
||||
self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref, queue_size=1, tcp_nodelay=True)
|
||||
self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True)
|
||||
self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1)
|
||||
self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy)
|
||||
@ -96,42 +95,31 @@ class YopoNet:
|
||||
# the following frame (The planner is planning from the desired state, instead of the actual state)
|
||||
def callback_odometry_ref(self, data):
|
||||
if not self.odom_ref_init:
|
||||
print("odom ref init")
|
||||
self.odom_ref_init = True
|
||||
self.odom_ref = data
|
||||
self.new_odom = True
|
||||
|
||||
def process_odom(self):
|
||||
# Rwb
|
||||
Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x,
|
||||
self.odom.pose.pose.orientation.y,
|
||||
self.odom.pose.pose.orientation.z,
|
||||
self.odom.pose.pose.orientation.w]).as_matrix()
|
||||
Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y,
|
||||
self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_matrix()
|
||||
self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc)
|
||||
|
||||
if self.odom_ref_init:
|
||||
odom_data = self.odom_ref
|
||||
# vel_b
|
||||
vel_w = np.array([odom_data.twist.twist.linear.x,
|
||||
odom_data.twist.twist.linear.y,
|
||||
odom_data.twist.twist.linear.z])
|
||||
vel_w = np.array([odom_data.twist.twist.linear.x, odom_data.twist.twist.linear.y, odom_data.twist.twist.linear.z])
|
||||
vel_b = np.dot(np.linalg.inv(self.Rotation_wc), vel_w)
|
||||
# acc_b
|
||||
acc_w = np.array([odom_data.twist.twist.angular.x, # acc stored in angular in our ref_state topic
|
||||
odom_data.twist.twist.angular.y,
|
||||
odom_data.twist.twist.angular.z])
|
||||
# acc_b (acc stored in angular in our ref_state topic)
|
||||
acc_w = np.array([odom_data.twist.twist.angular.x, odom_data.twist.twist.angular.y, odom_data.twist.twist.angular.z])
|
||||
acc_b = np.dot(np.linalg.inv(self.Rotation_wc), acc_w)
|
||||
else:
|
||||
odom_data = self.odom
|
||||
vel_b = np.array([0.0, 0.0, 0.0])
|
||||
acc_b = np.array([0.0, 0.0, 0.0])
|
||||
|
||||
# pose
|
||||
pos = np.array([odom_data.pose.pose.position.x,
|
||||
odom_data.pose.pose.position.y,
|
||||
odom_data.pose.pose.position.z])
|
||||
|
||||
# goal_dir
|
||||
# pose and goal_dir
|
||||
pos = np.array([odom_data.pose.pose.position.x, odom_data.pose.pose.position.y, odom_data.pose.pose.position.z])
|
||||
goal_w = (self.goal - pos) / np.linalg.norm(self.goal - pos)
|
||||
goal_b = np.dot(np.linalg.inv(self.Rotation_wc), goal_w)
|
||||
|
||||
@ -143,15 +131,12 @@ class YopoNet:
|
||||
def callback_depth(self, data):
|
||||
max_dis = 20.0
|
||||
min_dis = 0.03
|
||||
if self.env == '435':
|
||||
scale = 0.001
|
||||
elif self.env == 'flightmare':
|
||||
scale = 1.0
|
||||
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
|
||||
|
||||
try:
|
||||
depth_ = self.bridge.imgmsg_to_cv2(data, "32FC1")
|
||||
except:
|
||||
print("CV_bridge ERROR: The ROS path is not included in Python Path!")
|
||||
print("CV_bridge ERROR: Possible solutions may be found at https://github.com/TJU-Aerial-Robotics/YOPO/issues/2")
|
||||
|
||||
if depth_.shape[0] != self.height or depth_.shape[1] != self.width:
|
||||
depth_ = cv2.resize(depth_, (self.width, self.height), interpolation=cv2.INTER_NEAREST)
|
||||
@ -257,15 +242,14 @@ class YopoNet:
|
||||
all_endstate_pred.layout.dim[1].label = "endstate_and_score_num"
|
||||
self.all_endstate_pub.publish(all_endstate_pred)
|
||||
self.goal_pub.publish(Float32MultiArray(data=self.goal))
|
||||
else:
|
||||
if not self.new_odom: # start a new round
|
||||
self.odom_ref_init = False
|
||||
# start a new round
|
||||
elif not self.new_odom:
|
||||
self.odom_ref_init = False
|
||||
|
||||
def trt_process(self, input_tensor: torch.Tensor, return_all_preds=False) -> torch.Tensor:
|
||||
batch_size = input_tensor.shape[0]
|
||||
input_tensor = input_tensor.cpu().numpy()
|
||||
input_tensor = input_tensor.reshape(batch_size, 10,
|
||||
self.lattice_space.horizon_num * self.lattice_space.vertical_num)
|
||||
input_tensor = input_tensor.reshape(batch_size, 10, self.lattice_space.horizon_num * self.lattice_space.vertical_num)
|
||||
endstate_pred = input_tensor[:, 0:9, :]
|
||||
score_pred = input_tensor[:, 9, :]
|
||||
|
||||
@ -292,9 +276,7 @@ class YopoNet:
|
||||
convert the observation from body frame to primitive frame,
|
||||
and then concatenate it with the depth features (to ensure the translational invariance)
|
||||
"""
|
||||
obs_return = np.ones(
|
||||
(obs.shape[0], self.lattice_space.vertical_num, self.lattice_space.horizon_num, obs.shape[1]),
|
||||
dtype=np.float32)
|
||||
obs_return = np.ones((obs.shape[0], self.lattice_space.vertical_num, self.lattice_space.horizon_num, obs.shape[1]), dtype=np.float32)
|
||||
id = 0
|
||||
v_b = obs[:, 0:3]
|
||||
a_b = obs[:, 3:6]
|
||||
@ -348,8 +330,7 @@ class YopoNet:
|
||||
trt_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device))
|
||||
self.trt_process(trt_output, return_all_preds=True)
|
||||
else:
|
||||
self.policy.predict(torch.from_numpy(depth).to(self.device), obs_input.to(self.device),
|
||||
return_all_preds=True)
|
||||
self.policy.predict(torch.from_numpy(depth).to(self.device), obs_input.to(self.device), return_all_preds=True)
|
||||
|
||||
|
||||
def parser():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user