diff --git a/README.md b/README.md index b5f28ac..e8d055d 100644 --- a/README.md +++ b/README.md @@ -184,17 +184,12 @@ source devel/setup.bash roslaunch so3_quadrotor_simulator simulator.launch ``` -**2.3** Start the YOPO inference and the Planner (The implementation of `yopo_planner_node` will be moved to `test_yopo_ros.py` in the future). You can refer to [traj_opt.yaml](./flightlib/configs/traj_opt.yaml) for modification of the flight speed (The given weights are pretrained at 6 m/s and perform smoothly at speeds between 0 - 6 m/s). +**2.3** Start the YOPO inference and the Planner. You can refer to [traj_opt.yaml](./flightlib/configs/traj_opt.yaml) for modification of the flight speed (The given weights are pretrained at 6 m/s and perform smoothly at speeds between 0 - 6 m/s). ``` cd ~/YOPO/run conda activate yopo -python test_yopo_ros.py --trial=1 --epoch=0 --iter=0 -``` - -``` -cd ~/YOPO/flightlib/build -./yopo_planner_node +python test_yopo_ros_new.py --trial=1 --epoch=0 --iter=0 ``` **2.4** Visualization: start the RVIZ and publish the map. diff --git a/flightlib/src/ros_nodes/yopo_planner_node.cpp b/flightlib/src/ros_nodes/yopo_planner_node.cpp index 48e9229..2a68bb9 100644 --- a/flightlib/src/ros_nodes/yopo_planner_node.cpp +++ b/flightlib/src/ros_nodes/yopo_planner_node.cpp @@ -136,8 +136,8 @@ void trajs_vis_cb(const std_msgs::Float32MultiArray::ConstPtr msg) { pcl::PointCloud::Ptr lattice_trajs_cld(new pcl::PointCloud); Eigen::Vector3d pos_1(0.0, 0.0, 0.0), vel_1(0.0, 0.0, 0.0), acc_1(0.0, 0.0, 0.0); for (size_t i = 0; i < lattice_nodes.size(); i++) { - pos_1 = lattice_nodes[i].first; - vel_1 = lattice_nodes[i].second; + pos_1 = quat_ * lattice_nodes[i].first; + vel_1 = quat_ * lattice_nodes[i].second; std::vector endstate_lattice = {pos_1(0), vel_1(0), acc_1(0), pos_1(1), vel_1(1), acc_1(1), pos_1(2), vel_1(2), acc_1(2)}; traj_opt_bridge_for_vis->solveBVP(endstate_lattice); traj_to_pcl(traj_opt_bridge_for_vis, lattice_trajs_cld); diff --git a/flightpolicy/control_msg/_PositionCommand.py b/flightpolicy/control_msg/_PositionCommand.py new file mode 100644 index 0000000..057f693 --- /dev/null +++ b/flightpolicy/control_msg/_PositionCommand.py @@ -0,0 +1,312 @@ +# This Python file uses the following encoding: utf-8 +"""autogenerated by genpy from quadrotor_msgs/PositionCommand.msg. Do not edit.""" +import codecs +import sys +python3 = True if sys.hexversion > 0x03000000 else False +import genpy +import struct + +import geometry_msgs.msg +import std_msgs.msg + +class PositionCommand(genpy.Message): + _md5sum = "4712f0609ca29a79af79a35ca3e3967a" + _type = "quadrotor_msgs/PositionCommand" + _has_header = True # flag to mark the presence of a Header object + _full_text = """Header header +geometry_msgs/Point position +geometry_msgs/Vector3 velocity +geometry_msgs/Vector3 acceleration +float64 yaw +float64 yaw_dot +float64[3] kx +float64[3] kv + +uint32 trajectory_id + +uint8 TRAJECTORY_STATUS_EMPTY = 0 +uint8 TRAJECTORY_STATUS_READY = 1 +uint8 TRAJECTORY_STATUS_COMPLETED = 3 +uint8 TRAJECTROY_STATUS_ABORT = 4 +uint8 TRAJECTORY_STATUS_ILLEGAL_START = 5 +uint8 TRAJECTORY_STATUS_ILLEGAL_FINAL = 6 +uint8 TRAJECTORY_STATUS_IMPOSSIBLE = 7 + +# Its ID number will start from 1, allowing you comparing it with 0. +uint8 trajectory_flag + +================================================================================ +MSG: std_msgs/Header +# Standard metadata for higher-level stamped data types. +# This is generally used to communicate timestamped data +# in a particular coordinate frame. +# +# sequence ID: consecutively increasing ID +uint32 seq +#Two-integer timestamp that is expressed as: +# * stamp.sec: seconds (stamp_secs) since epoch (in Python the variable is called 'secs') +# * stamp.nsec: nanoseconds since stamp_secs (in Python the variable is called 'nsecs') +# time-handling sugar is provided by the client library +time stamp +#Frame this data is associated with +string frame_id + +================================================================================ +MSG: geometry_msgs/Point +# This contains the position of a point in free space +float64 x +float64 y +float64 z + +================================================================================ +MSG: geometry_msgs/Vector3 +# This represents a vector in free space. +# It is only meant to represent a direction. Therefore, it does not +# make sense to apply a translation to it (e.g., when applying a +# generic rigid transformation to a Vector3, tf2 will only apply the +# rotation). If you want your data to be translatable too, use the +# geometry_msgs/Point message instead. + +float64 x +float64 y +float64 z""" + # Pseudo-constants + TRAJECTORY_STATUS_EMPTY = 0 + TRAJECTORY_STATUS_READY = 1 + TRAJECTORY_STATUS_COMPLETED = 3 + TRAJECTROY_STATUS_ABORT = 4 + TRAJECTORY_STATUS_ILLEGAL_START = 5 + TRAJECTORY_STATUS_ILLEGAL_FINAL = 6 + TRAJECTORY_STATUS_IMPOSSIBLE = 7 + + __slots__ = ['header','position','velocity','acceleration','yaw','yaw_dot','kx','kv','trajectory_id','trajectory_flag'] + _slot_types = ['std_msgs/Header','geometry_msgs/Point','geometry_msgs/Vector3','geometry_msgs/Vector3','float64','float64','float64[3]','float64[3]','uint32','uint8'] + + def __init__(self, *args, **kwds): + """ + Constructor. Any message fields that are implicitly/explicitly + set to None will be assigned a default value. The recommend + use is keyword arguments as this is more robust to future message + changes. You cannot mix in-order arguments and keyword arguments. + + The available fields are: + header,position,velocity,acceleration,yaw,yaw_dot,kx,kv,trajectory_id,trajectory_flag + + :param args: complete set of field values, in .msg order + :param kwds: use keyword arguments corresponding to message field names + to set specific fields. + """ + if args or kwds: + super(PositionCommand, self).__init__(*args, **kwds) + # message fields cannot be None, assign default values for those that are + if self.header is None: + self.header = std_msgs.msg.Header() + if self.position is None: + self.position = geometry_msgs.msg.Point() + if self.velocity is None: + self.velocity = geometry_msgs.msg.Vector3() + if self.acceleration is None: + self.acceleration = geometry_msgs.msg.Vector3() + if self.yaw is None: + self.yaw = 0. + if self.yaw_dot is None: + self.yaw_dot = 0. + if self.kx is None: + self.kx = [0.] * 3 + if self.kv is None: + self.kv = [0.] * 3 + if self.trajectory_id is None: + self.trajectory_id = 0 + if self.trajectory_flag is None: + self.trajectory_flag = 0 + else: + self.header = std_msgs.msg.Header() + self.position = geometry_msgs.msg.Point() + self.velocity = geometry_msgs.msg.Vector3() + self.acceleration = geometry_msgs.msg.Vector3() + self.yaw = 0. + self.yaw_dot = 0. + self.kx = [0.] * 3 + self.kv = [0.] * 3 + self.trajectory_id = 0 + self.trajectory_flag = 0 + + def _get_types(self): + """ + internal API method + """ + return self._slot_types + + def serialize(self, buff): + """ + serialize message into buffer + :param buff: buffer, ``StringIO`` + """ + try: + _x = self + buff.write(_get_struct_3I().pack(_x.header.seq, _x.header.stamp.secs, _x.header.stamp.nsecs)) + _x = self.header.frame_id + length = len(_x) + if python3 or type(_x) == unicode: + _x = _x.encode('utf-8') + length = len(_x) + buff.write(struct.Struct(' 0.2 else last_yaw_ + max_yaw_change = YAW_DOT_MAX_PER_SEC * dt + + # Initialize yaw and yawdot + yaw = last_yaw_ + yawdot = 0 + + # Logic for yaw adjustment + if yaw_temp - last_yaw_ > np.pi: + if yaw_temp - last_yaw_ - 2 * np.pi < -max_yaw_change: + yaw = last_yaw_ - max_yaw_change + if yaw < -np.pi: + yaw += 2 * np.pi + yawdot = -YAW_DOT_MAX_PER_SEC + else: + yaw = yaw_temp + if yaw - last_yaw_ > np.pi: + yawdot = -YAW_DOT_MAX_PER_SEC + else: + yawdot = (yaw_temp - last_yaw_) / dt + elif yaw_temp - last_yaw_ < -np.pi: + if yaw_temp - last_yaw_ + 2 * np.pi > max_yaw_change: + yaw = last_yaw_ + max_yaw_change + if yaw > np.pi: + yaw -= 2 * np.pi + yawdot = YAW_DOT_MAX_PER_SEC + else: + yaw = yaw_temp + if yaw - last_yaw_ < -np.pi: + yawdot = YAW_DOT_MAX_PER_SEC + else: + yawdot = (yaw_temp - last_yaw_) / dt + else: + if yaw_temp - last_yaw_ < -max_yaw_change: + yaw = last_yaw_ - max_yaw_change + if yaw < -np.pi: + yaw += 2 * np.pi + yawdot = -YAW_DOT_MAX_PER_SEC + elif yaw_temp - last_yaw_ > max_yaw_change: + yaw = last_yaw_ + max_yaw_change + if yaw > np.pi: + yaw -= 2 * np.pi + yawdot = YAW_DOT_MAX_PER_SEC + else: + yaw = yaw_temp + if yaw - last_yaw_ > np.pi: + yawdot = -YAW_DOT_MAX_PER_SEC + elif yaw - last_yaw_ < -np.pi: + yawdot = YAW_DOT_MAX_PER_SEC + else: + yawdot = (yaw_temp - last_yaw_) / dt + + return yaw, yawdot \ No newline at end of file diff --git a/run/test_yopo_ros.py b/run/test_yopo_ros.py index 40c17a0..bd96b18 100644 --- a/run/test_yopo_ros.py +++ b/run/test_yopo_ros.py @@ -1,3 +1,15 @@ +""" + YOPO Network Inference NODE: + Subscribe odometry and depth messages, and perform network inference + Use: + $ cd ~/YOPO/run + $ conda activate yopo + $ python test_yopo_ros.py --trial=1 --epoch=0 --iter=0 + + $ cd ~/YOPO/flightlib/build + $ ./yopo_planner_node +""" + import rospy from sensor_msgs.msg import Image from nav_msgs.msg import Odometry diff --git a/run/test_yopo_ros_new.py b/run/test_yopo_ros_new.py new file mode 100644 index 0000000..6acb35e --- /dev/null +++ b/run/test_yopo_ros_new.py @@ -0,0 +1,435 @@ +""" + YOPO ROS NODE: + Subscribe odometry and depth messages, perform network inference, solve trajectory, and publish control commands. + Used to replace test_yopo_ros.py and yopo_planner_node.cpp. + If you encounter issues (such as unsmooth) with this script, try using the following instead: + $ cd ~/YOPO/run + $ conda activate yopo + $ python test_yopo_ros.py --trial=1 --epoch=0 --iter=0 + $ cd ~/YOPO/flightlib/build + $ ./yopo_planner_node +""" +import rospy +import std_msgs.msg +from nav_msgs.msg import Odometry +from geometry_msgs.msg import PoseStamped +from sensor_msgs.msg import PointCloud2, PointField, Image +from sensor_msgs import point_cloud2 +from cv_bridge import CvBridge +from threading import Lock +import numpy as np +import cv2 +import os +import torch +import argparse +import time +from ruamel.yaml import YAML +from scipy.spatial.transform import Rotation as R +from flightpolicy.control_msg import PositionCommand +from flightpolicy.yopo.yopo_policy import YopoPolicy +from flightpolicy.yopo.primitive_utils import LatticeParam, LatticePrimitive, Poly5Solver, Polys5Solver, calculate_yaw + +try: + from torch2trt import TRTModule +except ImportError: + print("tensorrt not found.") + + +class YopoNet: + def __init__(self, config, weight): + self.config = config + rospy.init_node('yopo_net', anonymous=False) + # load params + self.height = self.config['img_height'] + self.width = self.config['img_width'] + self.goal = np.array(self.config['goal']) + self.env = self.config['env'] + self.use_trt = self.config['use_tensorrt'] + self.verbose = self.config['verbose'] + self.visualize = self.config['visualize'] + self.Rotation_bc = R.from_euler('ZYX', [0, self.config['pitch_angle_deg'], 0], degrees=True).as_matrix() + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml", 'r')) + self.lattice_space = LatticeParam(cfg) + self.lattice_primitive = LatticePrimitive(self.lattice_space) + + # variables + self.bridge = CvBridge() + self.odom = Odometry() + self.odom_init = False + self.last_yaw = 0.0 + self.ctrl_dt = 0.02 + self.ctrl_time = None + self.desire_init = False + self.arrive = False + self.desire_pos = None + self.desire_vel = None + self.desire_acc = None + self.optimal_poly_x = None + self.optimal_poly_y = None + self.optimal_poly_z = None + self.lock = Lock() + + # eval + self.time_forward = 0.0 + self.time_process = 0.0 + self.time_prepare = 0.0 + self.time_interpolation = 0.0 + self.time_visualize = 0.0 + self.count = 0 + + # Load Network + if self.use_trt: + self.policy = TRTModule() + self.policy.load_state_dict(torch.load(weight)) + else: + saved_variables = torch.load(weight, map_location=self.device) + saved_variables["data"]["lattice_space"] = self.lattice_space + saved_variables["data"]["lattice_primitive"] = self.lattice_primitive + self.policy = YopoPolicy(device=self.device, **saved_variables["data"]) + self.policy.load_state_dict(saved_variables["state_dict"], strict=False) + self.policy.to(self.device) + self.policy.set_training_mode(False) + torch.set_grad_enabled(False) + self.warm_up() + + # ros publisher + odom_topic = self.config['odom_topic'] + depth_topic = self.config['depth_topic'] + self.lattice_traj_pub = rospy.Publisher("/yopo_net/lattice_trajs_visual", PointCloud2, queue_size=1) + self.best_traj_pub = rospy.Publisher("/yopo_net/best_traj_visual", PointCloud2, queue_size=1) + self.all_trajs_pub = rospy.Publisher("/yopo_net/trajs_visual", PointCloud2, queue_size=1) + self.ctrl_pub = rospy.Publisher("/so3_control/pos_cmd", PositionCommand, queue_size=1) + # ros subscriber + self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.callback_odometry, queue_size=1) + self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1) + self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1) + # ros timer + rospy.sleep(1.0) # wait connection... + self.timer_ctrl = rospy.Timer(rospy.Duration(self.ctrl_dt), self.control_pub) + print("YOPO Net Node Ready!") + rospy.spin() + + def callback_set_goal(self, data): + self.goal = np.asarray([data.pose.position.x, data.pose.position.y, 2]) + self.arrive = False + print(f"New Goal: ({data.pose.position.x:.1f}, {data.pose.position.y:.1f})") + + # the first frame + def callback_odometry(self, data): + self.odom = data + if not self.desire_init: + self.desire_pos = np.array((self.odom.pose.pose.position.x, self.odom.pose.pose.position.y, self.odom.pose.pose.position.z)) + self.desire_vel = np.array((self.odom.twist.twist.linear.x, self.odom.twist.twist.linear.y, self.odom.twist.twist.linear.z)) + self.desire_acc = np.array((0.0, 0.0, 0.0)) + ypr = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y, + self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_euler('ZYX', degrees=False) + self.last_yaw = ypr[0] + self.odom_init = True + + pos = np.array((self.odom.pose.pose.position.x, self.odom.pose.pose.position.y, self.odom.pose.pose.position.z)) + if np.linalg.norm(pos - self.goal) < 4 and not self.arrive: + print("Arrive!") + self.arrive = True + + def process_odom(self): + # Rwb -> Rwc -> Rcw + Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y, + self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_matrix() + self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc) + Rotation_cw = self.Rotation_wc.T + + # vel and acc + vel_w = self.desire_vel + vel_c = np.dot(Rotation_cw, vel_w) + acc_w = self.desire_acc + acc_c = np.dot(Rotation_cw, acc_w) + + # pose and goal_dir + goal_w = (self.goal - self.desire_pos) / np.linalg.norm(self.goal - self.desire_pos) + goal_c = np.dot(Rotation_cw, goal_w) + + vel_acc = np.concatenate((vel_c, acc_c), axis=0) + vel_acc_norm = self.normalize_obs(vel_acc[np.newaxis, :]) + obs_norm = np.hstack((vel_acc_norm, goal_c[np.newaxis, :])) + return obs_norm + + def callback_depth(self, data): + if not self.odom_init: + return + + # 1. Depth Image Process + min_dis, max_dis = 0.03, 20.0 + scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0) + + try: + depth = self.bridge.imgmsg_to_cv2(data, "32FC1") + except: + print("CV_bridge ERROR: Possible solutions may be found at https://github.com/TJU-Aerial-Robotics/YOPO/issues/2") + + time0 = time.time() + if depth.shape[0] != self.height or depth.shape[1] != self.width: + depth = cv2.resize(depth, (self.width, self.height), interpolation=cv2.INTER_NEAREST) + depth = np.minimum(depth * scale, max_dis) / max_dis + + # interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results) + nan_mask = np.isnan(depth) | (depth < min_dis) + interpolated_image = cv2.inpaint(np.uint8(depth * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS) + interpolated_image = interpolated_image.astype(np.float32) / 255.0 + depth = interpolated_image.reshape([1, 1, self.height, self.width]) + # cv2.imshow("1", depth[0][0]) + # cv2.waitKey(1) + + # 2. YOPO Network Inference + # input prepare + time1 = time.time() + depth_input = torch.from_numpy(depth).to(self.device, non_blocking=True) # (non_blocking: copying speed 3x) + obs = self.process_odom() + obs_input = self.prepare_input_observation(obs) + obs_input = obs_input.to(self.device, non_blocking=True) + # torch.cuda.synchronize() + + time2 = time.time() + # Forward (TensorRT: inference speed increased by 5x) + with torch.no_grad(): + network_output = self.policy(depth_input, obs_input) + network_output = network_output.cpu().numpy() # torch.cuda.synchronize() is not needed here + time3 = time.time() + # Replacing PyTorch operation on CUDA with NumPy operation on CPU (speed increased by 10x) + endstate_pred, score_pred = self.process_output(network_output, return_all_preds=self.visualize) + # Vectorization: transform the prediction(P V A in body frame) to the world frame with the attitude (without the position) + endstate_c = endstate_pred.T.reshape(-1, 3, 3) + endstate_w = np.matmul(self.Rotation_wc, endstate_c) + # endstate_w = endstate_w.reshape(-1, 9).T + + action_id = np.argmin(score_pred) if self.visualize else 0 + with self.lock: # Python3.8: threads are scheduled using time slices, add the lock to ensure safety + self.optimal_poly_x = Poly5Solver(self.desire_pos[0], self.desire_vel[0], self.desire_acc[0], + endstate_w[action_id, 0, 0] + self.desire_pos[0], endstate_w[action_id, 0, 1], endstate_w[action_id, 0, 2], self.lattice_space.segment_time) + self.optimal_poly_y = Poly5Solver(self.desire_pos[1], self.desire_vel[1], self.desire_acc[1], + endstate_w[action_id, 1, 0] + self.desire_pos[1], endstate_w[action_id, 1, 1], endstate_w[action_id, 1, 2], self.lattice_space.segment_time) + self.optimal_poly_z = Poly5Solver(self.desire_pos[2], self.desire_vel[2], self.desire_acc[2], + endstate_w[action_id, 2, 0] + self.desire_pos[2], endstate_w[action_id, 2, 1], endstate_w[action_id, 2, 2], self.lattice_space.segment_time) + self.ctrl_time = 0.0 + time4 = time.time() + self.visualize_trajectory(score_pred, endstate_w) + time5 = time.time() + + if self.verbose: + self.time_interpolation = self.time_interpolation + (time1 - time0) + self.time_prepare = self.time_prepare + (time2 - time1) + self.time_forward = self.time_forward + (time3 - time2) + self.time_process = self.time_process + (time4 - time3) + self.time_visualize = self.time_visualize + (time5 - time4) + self.count = self.count + 1 + print(f"Time Consuming:" + f"depth-interpolation: {1000 * self.time_interpolation / self.count:.2f}ms;" + f"data-prepare: {1000 * self.time_prepare / self.count:.2f}ms; " + f"network-inference: {1000 * self.time_forward / self.count:.2f}ms; " + f"post-process: {1000 * self.time_process / self.count:.2f}ms;" + f"visualize-trajectory: {1000 * self.time_visualize / self.count:.2f}ms") + + def control_pub(self, _timer): + if self.ctrl_time is None or self.ctrl_time > self.lattice_space.segment_time: + return + if self.arrive: + self.desire_init = False # ready for next rollout + return + with self.lock: # Python3.8: threads are scheduled using time slices, add the lock to ensure safety and publish frequency + self.ctrl_time += self.ctrl_dt + control_msg = PositionCommand() + control_msg.header.stamp = rospy.Time.now() + control_msg.trajectory_flag = control_msg.TRAJECTORY_STATUS_READY + control_msg.position.x = self.optimal_poly_x.get_position(self.ctrl_time) + control_msg.position.y = self.optimal_poly_y.get_position(self.ctrl_time) + control_msg.position.z = self.optimal_poly_z.get_position(self.ctrl_time) + control_msg.velocity.x = self.optimal_poly_x.get_velocity(self.ctrl_time) + control_msg.velocity.y = self.optimal_poly_y.get_velocity(self.ctrl_time) + control_msg.velocity.z = self.optimal_poly_z.get_velocity(self.ctrl_time) + control_msg.acceleration.x = self.optimal_poly_x.get_acceleration(self.ctrl_time) + control_msg.acceleration.y = self.optimal_poly_y.get_acceleration(self.ctrl_time) + control_msg.acceleration.z = self.optimal_poly_z.get_acceleration(self.ctrl_time) + self.desire_pos = np.array([control_msg.position.x, control_msg.position.y, control_msg.position.z]) + self.desire_vel = np.array([control_msg.velocity.x, control_msg.velocity.y, control_msg.velocity.z]) + self.desire_acc = np.array([control_msg.acceleration.x, control_msg.acceleration.y, control_msg.acceleration.z]) + goal_dir = self.goal - self.desire_pos + yaw, yaw_dot = calculate_yaw(self.desire_vel, goal_dir, self.last_yaw, self.ctrl_dt) + self.last_yaw = yaw + control_msg.yaw = yaw + control_msg.yaw_dot = yaw_dot + self.desire_init = True + self.ctrl_pub.publish(control_msg) + + def process_output(self, network_output, return_all_preds=False): + if network_output.shape[0] != 1: + raise ValueError("batch of output values must be 1 in test!") + network_output = network_output.reshape(10, self.lattice_space.horizon_num * self.lattice_space.vertical_num) + endstate_pred = network_output[0:9, :] + score_pred = network_output[9, :] + + if not return_all_preds: + action_id = np.argmin(score_pred) + lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id + endstate_prediction = self.pred_to_endstate(endstate_pred[:, action_id], lattice_id) + endstate_prediction = endstate_prediction[:, np.newaxis] + score_prediction = score_pred[action_id] + else: + endstate_prediction = np.zeros_like(endstate_pred) + score_prediction = score_pred + for i in range(self.lattice_space.horizon_num * self.lattice_space.vertical_num): + lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i + endstate_prediction[:, i] = self.pred_to_endstate(endstate_pred[:, i], lattice_id) + + return endstate_prediction, score_prediction + + def prepare_input_observation(self, obs): + """ + convert the observation from body frame to primitive frame, + and then concatenate it with the depth features (to ensure the translational invariance) + """ + if obs.shape[0] != 1: + raise ValueError("batch of input observations must be 1 in test!") + + obs_return = np.ones((obs.shape[0], obs.shape[1], self.lattice_space.vertical_num, self.lattice_space.horizon_num), dtype=np.float32) + id = 0 + obs_reshaped = obs.reshape(3, 3) + for i in range(self.lattice_space.vertical_num - 1, -1, -1): + for j in range(self.lattice_space.horizon_num - 1, -1, -1): + Rbp = self.lattice_primitive.getRotation(id) + obs_return_reshaped = np.dot(obs_reshaped, Rbp) + obs_return[:, :, i, j] = obs_return_reshaped.reshape(9) + id = id + 1 + return torch.from_numpy(obs_return) + + def pred_to_endstate(self, endstate_pred: np.ndarray, id: int): + """ + Transform the predicted state to the body frame. + """ + delta_yaw = endstate_pred[0] * self.lattice_primitive.yaw_diff + delta_pitch = endstate_pred[1] * self.lattice_primitive.pitch_diff + radio = endstate_pred[2] * self.lattice_space.radio_range + self.lattice_space.radio_range + yaw, pitch = self.lattice_primitive.getAngleLattice(id) + endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio + endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio + endstate_z = np.sin(pitch + delta_pitch) * radio + endstate_p = np.array((endstate_x, endstate_y, endstate_z)) + + endstate_vp = endstate_pred[3:6] * self.lattice_space.vel_max + endstate_ap = endstate_pred[6:9] * self.lattice_space.acc_max + Rpb = self.lattice_primitive.getRotation(id).T + endstate_vb = np.matmul(endstate_vp, Rpb) + endstate_ab = np.matmul(endstate_ap, Rpb) + endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab)) + endstate[[0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[[0, 3, 6, 1, 4, 7, 2, 5, 8]] + return endstate + + def normalize_obs(self, vel_acc): + vel_norm = vel_acc[:, 0:3] / self.lattice_space.vel_max + acc_norm = vel_acc[:, 3:6] / self.lattice_space.acc_max + return np.hstack((vel_norm, acc_norm)) + + def visualize_trajectory(self, pred_score, pred_endstate): + dt = self.lattice_space.segment_time / 20.0 + # best predicted trajectory + if self.best_traj_pub.get_num_connections() > 0: + t_values = np.arange(0, self.lattice_space.segment_time, dt) + points_array = np.stack(( + self.optimal_poly_x.get_position(t_values), + self.optimal_poly_y.get_position(t_values), + self.optimal_poly_z.get_position(t_values) + ), axis=-1) + header = std_msgs.msg.Header() + header.stamp = rospy.Time.now() + header.frame_id = 'world' + point_cloud_msg = point_cloud2.create_cloud_xyz32(header, points_array) + self.best_traj_pub.publish(point_cloud_msg) + # lattice primitive + if self.visualize and self.lattice_traj_pub.get_num_connections() > 0: + lattice_endstate = self.lattice_primitive.lattice_pos_node + lattice_endstate = np.dot(lattice_endstate, self.Rotation_wc.T) + zero_state = np.zeros_like(lattice_endstate) + lattice_poly_x = Polys5Solver(self.desire_pos[0], self.desire_vel[0], self.desire_acc[0], + lattice_endstate[:, 0] + self.desire_pos[0], zero_state[:, 0], zero_state[:, 0], self.lattice_space.segment_time) + lattice_poly_y = Polys5Solver(self.desire_pos[1], self.desire_vel[1], self.desire_acc[1], + lattice_endstate[:, 1] + self.desire_pos[1], zero_state[:, 1], zero_state[:, 1], self.lattice_space.segment_time) + lattice_poly_z = Polys5Solver(self.desire_pos[2], self.desire_vel[2], self.desire_acc[2], + lattice_endstate[:, 2] + self.desire_pos[2], zero_state[:, 2], zero_state[:, 2], self.lattice_space.segment_time) + t_values = np.arange(0, self.lattice_space.segment_time, dt) + points_array = np.stack(( + lattice_poly_x.get_position(t_values), + lattice_poly_y.get_position(t_values), + lattice_poly_z.get_position(t_values) + ), axis=-1) + header = std_msgs.msg.Header() + header.stamp = rospy.Time.now() + header.frame_id = 'world' + point_cloud_msg = point_cloud2.create_cloud_xyz32(header, points_array) + self.lattice_traj_pub.publish(point_cloud_msg) + # all predicted trajectories + if self.visualize and self.all_trajs_pub.get_num_connections() > 0: + all_poly_x = Polys5Solver(self.desire_pos[0], self.desire_vel[0], self.desire_acc[0], + pred_endstate[:, 0, 0] + self.desire_pos[0], pred_endstate[:, 0, 1], pred_endstate[:, 0, 2], self.lattice_space.segment_time) + all_poly_y = Polys5Solver(self.desire_pos[1], self.desire_vel[1], self.desire_acc[1], + pred_endstate[:, 1, 0] + self.desire_pos[1], pred_endstate[:, 1, 1], pred_endstate[:, 1, 2], self.lattice_space.segment_time) + all_poly_z = Polys5Solver(self.desire_pos[2], self.desire_vel[2], self.desire_acc[2], + pred_endstate[:, 2, 0] + self.desire_pos[2], pred_endstate[:, 2, 1], pred_endstate[:, 2, 2], self.lattice_space.segment_time) + t_values = np.arange(0, self.lattice_space.segment_time, dt) + points_array = np.stack(( + all_poly_x.get_position(t_values), + all_poly_y.get_position(t_values), + all_poly_z.get_position(t_values) + ), axis=-1) + scores = np.repeat(pred_score, t_values.size) + points_array = np.column_stack((points_array, scores)) + header = std_msgs.msg.Header() + header.stamp = rospy.Time.now() + header.frame_id = 'world' + fields = [PointField('x', 0, PointField.FLOAT32, 1), PointField('y', 4, PointField.FLOAT32, 1), + PointField('z', 8, PointField.FLOAT32, 1), PointField('intensity', 12, PointField.FLOAT32, 1)] + point_cloud_msg = point_cloud2.create_cloud(header, fields, points_array) + self.all_trajs_pub.publish(point_cloud_msg) + + def warm_up(self): + depth = np.zeros(shape=[1, 1, self.height, self.width], dtype=np.float32) + obs = np.zeros(shape=[1, 9], dtype=np.float32) + obs_input = self.prepare_input_observation(obs) + network_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device)) + self.process_output(network_output.cpu().numpy(), return_all_preds=True) + + +def parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--use_tensorrt", type=int, default=0, help="use tensorrt or not") + parser.add_argument("--trial", type=int, default=1, help="trial number") + parser.add_argument("--epoch", type=int, default=0, help="epoch number") + parser.add_argument("--iter", type=int, default=0, help="iter number") + parser.add_argument("--trt_file", type=str, default='yopo_trt.pth', help="tensorrt filename") + return parser + + +# In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform +# When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system +# Change the flight speed at traj_opt.yaml and there's no need to re-collect and re-train +def main(): + args = parser().parse_args() + rsg_root = os.path.dirname(os.path.abspath(__file__)) + weight = args.trt_file if args.use_tensorrt else f"{rsg_root}/saved/YOPO_{args.trial}/Policy/epoch{args.epoch}_iter{args.iter}.pth" + print("load weight from:", weight) + + settings = {'use_tensorrt': args.use_tensorrt, + 'img_height': 96, + 'img_width': 160, + 'goal': [20, 20, 2], # the goal + 'env': 'flightmare', # use Realsense D435 or Flightmare Simulator ('435' or 'flightmare') + 'pitch_angle_deg': -5, # pitch of camera, ensure consistent with the simulator or your platform (no need to re-collect and re-train when modifying) + 'odom_topic': '/juliett/ground_truth/odom', + 'depth_topic': '/depth_image', + 'verbose': False, # print the latency? + 'visualize': True # visualize all predictions? set False in real flight + } + YopoNet(settings, weight) + + +if __name__ == "__main__": + main()