move test_yopo_ros.py and yopo_planner_node.cpp to a single script

This commit is contained in:
TJU_Lu 2024-12-24 18:03:49 +08:00
parent 09e832c829
commit 1a9f7c9f42
7 changed files with 892 additions and 12 deletions

View File

@ -184,17 +184,12 @@ source devel/setup.bash
roslaunch so3_quadrotor_simulator simulator.launch
```
**2.3** Start the YOPO inference and the Planner (The implementation of `yopo_planner_node` will be moved to `test_yopo_ros.py` in the future). You can refer to [traj_opt.yaml](./flightlib/configs/traj_opt.yaml) for modification of the flight speed (The given weights are pretrained at 6 m/s and perform smoothly at speeds between 0 - 6 m/s).
**2.3** Start the YOPO inference and the Planner. You can refer to [traj_opt.yaml](./flightlib/configs/traj_opt.yaml) for modification of the flight speed (The given weights are pretrained at 6 m/s and perform smoothly at speeds between 0 - 6 m/s).
```
cd ~/YOPO/run
conda activate yopo
python test_yopo_ros.py --trial=1 --epoch=0 --iter=0
```
```
cd ~/YOPO/flightlib/build
./yopo_planner_node
python test_yopo_ros_new.py --trial=1 --epoch=0 --iter=0
```
**2.4** Visualization: start the RVIZ and publish the map.

View File

@ -136,8 +136,8 @@ void trajs_vis_cb(const std_msgs::Float32MultiArray::ConstPtr msg) {
pcl::PointCloud<pcl::PointXYZI>::Ptr lattice_trajs_cld(new pcl::PointCloud<pcl::PointXYZI>);
Eigen::Vector3d pos_1(0.0, 0.0, 0.0), vel_1(0.0, 0.0, 0.0), acc_1(0.0, 0.0, 0.0);
for (size_t i = 0; i < lattice_nodes.size(); i++) {
pos_1 = lattice_nodes[i].first;
vel_1 = lattice_nodes[i].second;
pos_1 = quat_ * lattice_nodes[i].first;
vel_1 = quat_ * lattice_nodes[i].second;
std::vector<double> endstate_lattice = {pos_1(0), vel_1(0), acc_1(0), pos_1(1), vel_1(1), acc_1(1), pos_1(2), vel_1(2), acc_1(2)};
traj_opt_bridge_for_vis->solveBVP(endstate_lattice);
traj_to_pcl(traj_opt_bridge_for_vis, lattice_trajs_cld);

View File

@ -0,0 +1,312 @@
# This Python file uses the following encoding: utf-8
"""autogenerated by genpy from quadrotor_msgs/PositionCommand.msg. Do not edit."""
import codecs
import sys
python3 = True if sys.hexversion > 0x03000000 else False
import genpy
import struct
import geometry_msgs.msg
import std_msgs.msg
class PositionCommand(genpy.Message):
    """Autogenerated genpy message class for quadrotor_msgs/PositionCommand.

    Wire layout (little-endian): Header (3x uint32 + length-prefixed frame_id),
    11x float64 (position, velocity, acceleration, yaw, yaw_dot),
    3x float64 kx, 3x float64 kv, then uint32 trajectory_id + uint8 trajectory_flag.
    """
    # Identity of the message type: md5 of the .msg definition and full type name.
    _md5sum = "4712f0609ca29a79af79a35ca3e3967a"
    _type = "quadrotor_msgs/PositionCommand"
    _has_header = True  # flag to mark the presence of a Header object
    # Raw .msg definition text, kept verbatim (NOTE: the upstream "TRAJECTROY"
    # typo must be preserved — changing it would desync the md5sum above).
    _full_text = """Header header
geometry_msgs/Point position
geometry_msgs/Vector3 velocity
geometry_msgs/Vector3 acceleration
float64 yaw
float64 yaw_dot
float64[3] kx
float64[3] kv
uint32 trajectory_id
uint8 TRAJECTORY_STATUS_EMPTY = 0
uint8 TRAJECTORY_STATUS_READY = 1
uint8 TRAJECTORY_STATUS_COMPLETED = 3
uint8 TRAJECTROY_STATUS_ABORT = 4
uint8 TRAJECTORY_STATUS_ILLEGAL_START = 5
uint8 TRAJECTORY_STATUS_ILLEGAL_FINAL = 6
uint8 TRAJECTORY_STATUS_IMPOSSIBLE = 7
# Its ID number will start from 1, allowing you comparing it with 0.
uint8 trajectory_flag
================================================================================
MSG: std_msgs/Header
# Standard metadata for higher-level stamped data types.
# This is generally used to communicate timestamped data
# in a particular coordinate frame.
#
# sequence ID: consecutively increasing ID
uint32 seq
#Two-integer timestamp that is expressed as:
# * stamp.sec: seconds (stamp_secs) since epoch (in Python the variable is called 'secs')
# * stamp.nsec: nanoseconds since stamp_secs (in Python the variable is called 'nsecs')
# time-handling sugar is provided by the client library
time stamp
#Frame this data is associated with
string frame_id
================================================================================
MSG: geometry_msgs/Point
# This contains the position of a point in free space
float64 x
float64 y
float64 z
================================================================================
MSG: geometry_msgs/Vector3
# This represents a vector in free space.
# It is only meant to represent a direction. Therefore, it does not
# make sense to apply a translation to it (e.g., when applying a
# generic rigid transformation to a Vector3, tf2 will only apply the
# rotation). If you want your data to be translatable too, use the
# geometry_msgs/Point message instead.
float64 x
float64 y
float64 z"""
    # Pseudo-constants mirroring the uint8 constants declared in the .msg above.
    TRAJECTORY_STATUS_EMPTY = 0
    TRAJECTORY_STATUS_READY = 1
    TRAJECTORY_STATUS_COMPLETED = 3
    TRAJECTROY_STATUS_ABORT = 4  # typo preserved from the .msg definition
    TRAJECTORY_STATUS_ILLEGAL_START = 5
    TRAJECTORY_STATUS_ILLEGAL_FINAL = 6
    TRAJECTORY_STATUS_IMPOSSIBLE = 7
    __slots__ = ['header','position','velocity','acceleration','yaw','yaw_dot','kx','kv','trajectory_id','trajectory_flag']
    _slot_types = ['std_msgs/Header','geometry_msgs/Point','geometry_msgs/Vector3','geometry_msgs/Vector3','float64','float64','float64[3]','float64[3]','uint32','uint8']

    def __init__(self, *args, **kwds):
        """
        Constructor. Any message fields that are implicitly/explicitly
        set to None will be assigned a default value. The recommend
        use is keyword arguments as this is more robust to future message
        changes. You cannot mix in-order arguments and keyword arguments.

        The available fields are:
            header,position,velocity,acceleration,yaw,yaw_dot,kx,kv,trajectory_id,trajectory_flag

        :param args: complete set of field values, in .msg order
        :param kwds: use keyword arguments corresponding to message field names
        to set specific fields.
        """
        if args or kwds:
            super(PositionCommand, self).__init__(*args, **kwds)
            # message fields cannot be None, assign default values for those that are
            if self.header is None:
                self.header = std_msgs.msg.Header()
            if self.position is None:
                self.position = geometry_msgs.msg.Point()
            if self.velocity is None:
                self.velocity = geometry_msgs.msg.Vector3()
            if self.acceleration is None:
                self.acceleration = geometry_msgs.msg.Vector3()
            if self.yaw is None:
                self.yaw = 0.
            if self.yaw_dot is None:
                self.yaw_dot = 0.
            if self.kx is None:
                self.kx = [0.] * 3
            if self.kv is None:
                self.kv = [0.] * 3
            if self.trajectory_id is None:
                self.trajectory_id = 0
            if self.trajectory_flag is None:
                self.trajectory_flag = 0
        else:
            # no arguments: every field gets its type's default value
            self.header = std_msgs.msg.Header()
            self.position = geometry_msgs.msg.Point()
            self.velocity = geometry_msgs.msg.Vector3()
            self.acceleration = geometry_msgs.msg.Vector3()
            self.yaw = 0.
            self.yaw_dot = 0.
            self.kx = [0.] * 3
            self.kv = [0.] * 3
            self.trajectory_id = 0
            self.trajectory_flag = 0

    def _get_types(self):
        """
        internal API method
        """
        return self._slot_types

    def serialize(self, buff):
        """
        serialize message into buffer
        :param buff: buffer, ``StringIO``
        """
        try:
            _x = self
            # Header: seq, stamp.secs, stamp.nsecs as three uint32
            buff.write(_get_struct_3I().pack(_x.header.seq, _x.header.stamp.secs, _x.header.stamp.nsecs))
            _x = self.header.frame_id
            length = len(_x)
            if python3 or type(_x) == unicode:
                _x = _x.encode('utf-8')
                length = len(_x)
            # frame_id: uint32 length prefix followed by utf-8 bytes
            buff.write(struct.Struct('<I%ss'%length).pack(length, _x))
            _x = self
            # position, velocity, acceleration, yaw, yaw_dot: 11 float64
            buff.write(_get_struct_11d().pack(_x.position.x, _x.position.y, _x.position.z, _x.velocity.x, _x.velocity.y, _x.velocity.z, _x.acceleration.x, _x.acceleration.y, _x.acceleration.z, _x.yaw, _x.yaw_dot))
            # fixed-size float64[3] gain arrays (no length prefix)
            buff.write(_get_struct_3d().pack(*self.kx))
            buff.write(_get_struct_3d().pack(*self.kv))
            _x = self
            buff.write(_get_struct_IB().pack(_x.trajectory_id, _x.trajectory_flag))
        except struct.error as se: self._check_types(struct.error("%s: '%s' when writing '%s'" % (type(se), str(se), str(locals().get('_x', self)))))
        except TypeError as te: self._check_types(ValueError("%s: '%s' when writing '%s'" % (type(te), str(te), str(locals().get('_x', self)))))

    def deserialize(self, str):
        """
        unpack serialized message in str into this message instance
        :param str: byte array of serialized message, ``str``
        """
        if python3:
            codecs.lookup_error("rosmsg").msg_type = self._type
        try:
            if self.header is None:
                self.header = std_msgs.msg.Header()
            if self.position is None:
                self.position = geometry_msgs.msg.Point()
            if self.velocity is None:
                self.velocity = geometry_msgs.msg.Vector3()
            if self.acceleration is None:
                self.acceleration = geometry_msgs.msg.Vector3()
            end = 0
            _x = self
            start = end
            end += 12
            (_x.header.seq, _x.header.stamp.secs, _x.header.stamp.nsecs,) = _get_struct_3I().unpack(str[start:end])
            start = end
            end += 4
            (length,) = _struct_I.unpack(str[start:end])
            start = end
            end += length
            if python3:
                self.header.frame_id = str[start:end].decode('utf-8', 'rosmsg')
            else:
                self.header.frame_id = str[start:end]
            _x = self
            start = end
            end += 88
            (_x.position.x, _x.position.y, _x.position.z, _x.velocity.x, _x.velocity.y, _x.velocity.z, _x.acceleration.x, _x.acceleration.y, _x.acceleration.z, _x.yaw, _x.yaw_dot,) = _get_struct_11d().unpack(str[start:end])
            start = end
            end += 24
            self.kx = _get_struct_3d().unpack(str[start:end])
            start = end
            end += 24
            self.kv = _get_struct_3d().unpack(str[start:end])
            _x = self
            start = end
            end += 5
            (_x.trajectory_id, _x.trajectory_flag,) = _get_struct_IB().unpack(str[start:end])
            return self
        except struct.error as e:
            raise genpy.DeserializationError(e) # most likely buffer underfill

    def serialize_numpy(self, buff, numpy):
        """
        serialize message with numpy array types into buffer
        :param buff: buffer, ``StringIO``
        :param numpy: numpy python module
        """
        try:
            _x = self
            buff.write(_get_struct_3I().pack(_x.header.seq, _x.header.stamp.secs, _x.header.stamp.nsecs))
            _x = self.header.frame_id
            length = len(_x)
            if python3 or type(_x) == unicode:
                _x = _x.encode('utf-8')
                length = len(_x)
            buff.write(struct.Struct('<I%ss'%length).pack(length, _x))
            _x = self
            buff.write(_get_struct_11d().pack(_x.position.x, _x.position.y, _x.position.z, _x.velocity.x, _x.velocity.y, _x.velocity.z, _x.acceleration.x, _x.acceleration.y, _x.acceleration.z, _x.yaw, _x.yaw_dot))
            # numpy variant: kx/kv are numpy arrays, dump their raw bytes directly
            buff.write(self.kx.tostring())
            buff.write(self.kv.tostring())
            _x = self
            buff.write(_get_struct_IB().pack(_x.trajectory_id, _x.trajectory_flag))
        except struct.error as se: self._check_types(struct.error("%s: '%s' when writing '%s'" % (type(se), str(se), str(locals().get('_x', self)))))
        except TypeError as te: self._check_types(ValueError("%s: '%s' when writing '%s'" % (type(te), str(te), str(locals().get('_x', self)))))

    def deserialize_numpy(self, str, numpy):
        """
        unpack serialized message in str into this message instance using numpy for array types
        :param str: byte array of serialized message, ``str``
        :param numpy: numpy python module
        """
        if python3:
            codecs.lookup_error("rosmsg").msg_type = self._type
        try:
            if self.header is None:
                self.header = std_msgs.msg.Header()
            if self.position is None:
                self.position = geometry_msgs.msg.Point()
            if self.velocity is None:
                self.velocity = geometry_msgs.msg.Vector3()
            if self.acceleration is None:
                self.acceleration = geometry_msgs.msg.Vector3()
            end = 0
            _x = self
            start = end
            end += 12
            (_x.header.seq, _x.header.stamp.secs, _x.header.stamp.nsecs,) = _get_struct_3I().unpack(str[start:end])
            start = end
            end += 4
            (length,) = _struct_I.unpack(str[start:end])
            start = end
            end += length
            if python3:
                self.header.frame_id = str[start:end].decode('utf-8', 'rosmsg')
            else:
                self.header.frame_id = str[start:end]
            _x = self
            start = end
            end += 88
            (_x.position.x, _x.position.y, _x.position.z, _x.velocity.x, _x.velocity.y, _x.velocity.z, _x.acceleration.x, _x.acceleration.y, _x.acceleration.z, _x.yaw, _x.yaw_dot,) = _get_struct_11d().unpack(str[start:end])
            start = end
            end += 24
            # numpy variant: view the raw bytes as a float64[3] array without copying element-wise
            self.kx = numpy.frombuffer(str[start:end], dtype=numpy.float64, count=3)
            start = end
            end += 24
            self.kv = numpy.frombuffer(str[start:end], dtype=numpy.float64, count=3)
            _x = self
            start = end
            end += 5
            (_x.trajectory_id, _x.trajectory_flag,) = _get_struct_IB().unpack(str[start:end])
            return self
        except struct.error as e:
            raise genpy.DeserializationError(e) # most likely buffer underfill
_struct_I = genpy.struct_I
def _get_struct_I():
global _struct_I
return _struct_I
_struct_11d = None
def _get_struct_11d():
global _struct_11d
if _struct_11d is None:
_struct_11d = struct.Struct("<11d")
return _struct_11d
_struct_3I = None
def _get_struct_3I():
global _struct_3I
if _struct_3I is None:
_struct_3I = struct.Struct("<3I")
return _struct_3I
_struct_3d = None
def _get_struct_3d():
global _struct_3d
if _struct_3d is None:
_struct_3d = struct.Struct("<3d")
return _struct_3d
_struct_IB = None
def _get_struct_IB():
global _struct_IB
if _struct_IB is None:
_struct_IB = struct.Struct("<IB")
return _struct_IB

View File

@ -0,0 +1 @@
from ._PositionCommand import *

View File

@ -5,7 +5,7 @@ from scipy.spatial.transform import Rotation as R
class LatticeParam():
def __init__(self, cfg):
self.vel_max = cfg["vel_max"]
segment_time = 2 * cfg["radio_range"] / self.vel_max
self.segment_time = 2 * cfg["radio_range"] / self.vel_max
self.horizon_num = cfg["horizon_num"]
self.vertical_num = cfg["vertical_num"]
self.radio_num = cfg["radio_num"]
@ -17,10 +17,10 @@ class LatticeParam():
self.radio_range = cfg["radio_range"]
self.vel_fov = cfg["vel_fov"]
self.vel_prefile = cfg["vel_prefile"]
self.acc_max = self.vel_max / segment_time
self.acc_max = self.vel_max / self.segment_time
print("---------------------")
print("| max speed = ", round(self.vel_max, 1), " |")
print("| traj time = ", round(segment_time, 1), " |")
print("| traj time = ", round(self.segment_time, 1), " |")
print("| max radio = ", round(2 * self.radio_range, 1), " |")
print("---------------------")
@ -95,6 +95,63 @@ class LatticePrimitive():
return self.lattice_Rbp_list[id]
class Poly5Solver:
    """Single-axis quintic (5th-order) polynomial boundary-value solver.

    Fits p(t) = A0 + A1*t + ... + A5*t^5 matching the full start state
    (pos0, vel0, acc0) at t = 0 and the end state (pos1, vel1, acc1) at t = Tf.
    """

    def __init__(self, pos0, vel0, acc0, pos1, vel1, acc1, Tf):
        """ 5-th order polynomial at each Axis """
        State_Mat = np.array([pos0, vel0, acc0, pos1, vel1, acc1])
        t = Tf
        # Closed-form inverse of the quintic boundary-condition matrix:
        # maps [p0, v0, a0, p1, v1, a1] directly to coefficients A0..A5.
        Coef_inv = np.array([[1, 0, 0, 0, 0, 0],
                             [0, 1, 0, 0, 0, 0],
                             [0, 0, 1 / 2, 0, 0, 0],
                             [-10 / t ** 3, -6 / t ** 2, -3 / (2 * t), 10 / t ** 3, -4 / t ** 2, 1 / (2 * t)],
                             [15 / t ** 4, 8 / t ** 3, 3 / (2 * t ** 2), -15 / t ** 4, 7 / t ** 3, -1 / t ** 2],
                             [-6 / t ** 5, -3 / t ** 4, -1 / (2 * t ** 3), 6 / t ** 5, -3 / t ** 4, 1 / (2 * t ** 3)]])
        self.A = np.dot(Coef_inv, State_Mat)

    def get_snap(self, t):
        """Return the scalar snap (4th derivative) at time t."""
        # fix: docstring previously (and wrongly) said "jerk"
        return 24 * self.A[4] + 120 * self.A[5] * t

    def get_jerk(self, t):
        """Return the scalar jerk at time t."""
        return 6 * self.A[3] + 24 * self.A[4] * t + 60 * self.A[5] * t * t

    def get_acceleration(self, t):
        """Return the scalar acceleration at time t."""
        return 2 * self.A[2] + 6 * self.A[3] * t + 12 * self.A[4] * t * t + 20 * self.A[5] * t * t * t

    def get_velocity(self, t):
        """Return the scalar velocity at time t."""
        return self.A[1] + 2 * self.A[2] * t + 3 * self.A[3] * t * t + 4 * self.A[4] * t * t * t + \
            5 * self.A[5] * t * t * t * t

    def get_position(self, t):
        """Return the scalar position at time t."""
        return self.A[0] + self.A[1] * t + self.A[2] * t * t + self.A[3] * t * t * t + self.A[4] * t * t * t * t + \
            self.A[5] * t * t * t * t * t
class Polys5Solver:
    """Batch of single-axis quintic polynomials sharing one start state
    (only used for visualization of multiple trajectories)."""

    def __init__(self, pos0, vel0, acc0, pos1, vel1, acc1, Tf):
        n = len(pos1)
        # Boundary conditions: common start state broadcast across all n
        # polynomials, individual end states per column.
        boundary = np.array([[pos0] * n, [vel0] * n, [acc0] * n, pos1, vel1, acc1])
        T = Tf
        # Closed-form inverse of the quintic boundary-condition matrix.
        inv_mat = np.array([[1, 0, 0, 0, 0, 0],
                            [0, 1, 0, 0, 0, 0],
                            [0, 0, 1 / 2, 0, 0, 0],
                            [-10 / T ** 3, -6 / T ** 2, -3 / (2 * T), 10 / T ** 3, -4 / T ** 2, 1 / (2 * T)],
                            [15 / T ** 4, 8 / T ** 3, 3 / (2 * T ** 2), -15 / T ** 4, 7 / T ** 3, -1 / T ** 2],
                            [-6 / T ** 5, -3 / T ** 4, -1 / (2 * T ** 3), 6 / T ** 5, -3 / T ** 4, 1 / (2 * T ** 3)]])
        # A has shape (6, n): one coefficient column per polynomial.
        self.A = inv_mat @ boundary

    def get_position(self, t):
        """Return the flattened position samples (polynomial-major) at time(s) t."""
        t = np.atleast_1d(t)
        # powers[k, j] = t_j ** k, so (A.T @ powers)[i, j] evaluates poly i at t_j.
        powers = t[np.newaxis, :] ** np.arange(6)[:, np.newaxis]
        return (self.A.T @ powers).flatten()
"""
From body to world
p_w = Rwb * p_b + t_w
@ -135,3 +192,71 @@ def rotate_inv(q_wb, pos_w): # quat: wxzy
def transform_inv(q_wb, tw, pos_w):
pos_b = rotate_inv(q_wb, pos_w - tw)
return pos_b
def calculate_yaw(vel_dir, goal_dir, last_yaw_, dt, max_yaw_rate=0.3):
    """Compute a rate-limited desired yaw and yaw rate.

    The desired heading blends the (normalized) velocity direction with the
    direction to the goal; the change from last_yaw_ is clamped to
    max_yaw_rate * pi rad/s, with wrap-around handling across the +/- pi seam.

    :param vel_dir: current velocity direction vector (at least 2D, world frame)
    :param goal_dir: vector from current position to the goal (world frame)
    :param last_yaw_: previously commanded yaw [rad]
    :param dt: control period [s]
    :param max_yaw_rate: yaw rate limit as a fraction of pi [rad/s / pi]
    :return: (yaw, yawdot) tuple
    """
    YAW_DOT_MAX_PER_SEC = max_yaw_rate * np.pi
    # Normalize direction of velocity (epsilon guards a near-zero velocity)
    vel_dir = vel_dir / (np.linalg.norm(vel_dir) + 1e-5)
    # Direction of goal
    goal_dist = np.linalg.norm(goal_dir)
    goal_dir = goal_dir / (goal_dist + 1e-5)  # Prevent division by zero
    # Desired direction: average of velocity heading and goal bearing
    dir_des = vel_dir + goal_dir
    # Temporary yaw calculation (hold the last yaw when essentially at the goal)
    yaw_temp = np.arctan2(dir_des[1], dir_des[0]) if goal_dist > 0.2 else last_yaw_
    max_yaw_change = YAW_DOT_MAX_PER_SEC * dt
    # Initialize yaw and yawdot
    yaw = last_yaw_
    yawdot = 0
    # Logic for yaw adjustment: the three top-level branches pick the short way
    # around the circle when the raw difference crosses the +/- pi seam.
    if yaw_temp - last_yaw_ > np.pi:
        # shorter to rotate negatively (wrap through -pi)
        if yaw_temp - last_yaw_ - 2 * np.pi < -max_yaw_change:
            yaw = last_yaw_ - max_yaw_change
            if yaw < -np.pi:
                yaw += 2 * np.pi
            yawdot = -YAW_DOT_MAX_PER_SEC
        else:
            yaw = yaw_temp
            if yaw - last_yaw_ > np.pi:
                yawdot = -YAW_DOT_MAX_PER_SEC
            else:
                yawdot = (yaw_temp - last_yaw_) / dt
    elif yaw_temp - last_yaw_ < -np.pi:
        # shorter to rotate positively (wrap through +pi)
        if yaw_temp - last_yaw_ + 2 * np.pi > max_yaw_change:
            yaw = last_yaw_ + max_yaw_change
            if yaw > np.pi:
                yaw -= 2 * np.pi
            yawdot = YAW_DOT_MAX_PER_SEC
        else:
            yaw = yaw_temp
            if yaw - last_yaw_ < -np.pi:
                yawdot = YAW_DOT_MAX_PER_SEC
            else:
                yawdot = (yaw_temp - last_yaw_) / dt
    else:
        # no seam crossing: plain rate limiting
        if yaw_temp - last_yaw_ < -max_yaw_change:
            yaw = last_yaw_ - max_yaw_change
            if yaw < -np.pi:
                yaw += 2 * np.pi
            yawdot = -YAW_DOT_MAX_PER_SEC
        elif yaw_temp - last_yaw_ > max_yaw_change:
            yaw = last_yaw_ + max_yaw_change
            if yaw > np.pi:
                yaw -= 2 * np.pi
            yawdot = YAW_DOT_MAX_PER_SEC
        else:
            yaw = yaw_temp
            if yaw - last_yaw_ > np.pi:
                yawdot = -YAW_DOT_MAX_PER_SEC
            elif yaw - last_yaw_ < -np.pi:
                yawdot = YAW_DOT_MAX_PER_SEC
            else:
                yawdot = (yaw_temp - last_yaw_) / dt
    return yaw, yawdot

View File

@ -1,3 +1,15 @@
"""
YOPO Network Inference NODE:
Subscribe odometry and depth messages, and perform network inference
Use:
$ cd ~/YOPO/run
$ conda activate yopo
$ python test_yopo_ros.py --trial=1 --epoch=0 --iter=0
$ cd ~/YOPO/flightlib/build
$ ./yopo_planner_node
"""
import rospy
from sensor_msgs.msg import Image
from nav_msgs.msg import Odometry

435
run/test_yopo_ros_new.py Normal file
View File

@ -0,0 +1,435 @@
"""
YOPO ROS NODE:
Subscribe odometry and depth messages, perform network inference, solve trajectory, and publish control commands.
Used to replace test_yopo_ros.py and yopo_planner_node.cpp.
If you encounter issues (such as unsmooth) with this script, try using the following instead:
$ cd ~/YOPO/run
$ conda activate yopo
$ python test_yopo_ros.py --trial=1 --epoch=0 --iter=0
$ cd ~/YOPO/flightlib/build
$ ./yopo_planner_node
"""
import rospy
import std_msgs.msg
from nav_msgs.msg import Odometry
from geometry_msgs.msg import PoseStamped
from sensor_msgs.msg import PointCloud2, PointField, Image
from sensor_msgs import point_cloud2
from cv_bridge import CvBridge
from threading import Lock
import numpy as np
import cv2
import os
import torch
import argparse
import time
from ruamel.yaml import YAML
from scipy.spatial.transform import Rotation as R
from flightpolicy.control_msg import PositionCommand
from flightpolicy.yopo.yopo_policy import YopoPolicy
from flightpolicy.yopo.primitive_utils import LatticeParam, LatticePrimitive, Poly5Solver, Polys5Solver, calculate_yaw
try:
from torch2trt import TRTModule
except ImportError:
print("tensorrt not found.")
class YopoNet:
    def __init__(self, config, weight):
        """Set up the YOPO inference node.

        Loads parameters, builds the primitive lattice, loads the network
        (TensorRT or PyTorch), wires ROS pub/sub/timer, then spins.

        :param config: dict-like test configuration (topics, image size, goal, flags)
        :param weight: path to the network weights (TensorRT state dict or .pth)
        """
        self.config = config
        rospy.init_node('yopo_net', anonymous=False)
        # load params
        self.height = self.config['img_height']
        self.width = self.config['img_width']
        self.goal = np.array(self.config['goal'])
        self.env = self.config['env']
        self.use_trt = self.config['use_tensorrt']
        self.verbose = self.config['verbose']
        self.visualize = self.config['visualize']
        # body->camera rotation from the configured camera pitch angle
        self.Rotation_bc = R.from_euler('ZYX', [0, self.config['pitch_angle_deg'], 0], degrees=True).as_matrix()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml", 'r'))
        self.lattice_space = LatticeParam(cfg)
        self.lattice_primitive = LatticePrimitive(self.lattice_space)
        # variables
        self.bridge = CvBridge()
        self.odom = Odometry()
        self.odom_init = False
        self.last_yaw = 0.0
        self.ctrl_dt = 0.02  # control publishing period [s] (50 Hz timer)
        self.ctrl_time = None  # time along the current trajectory; None until first inference
        self.desire_init = False
        self.arrive = False
        self.desire_pos = None
        self.desire_vel = None
        self.desire_acc = None
        self.optimal_poly_x = None
        self.optimal_poly_y = None
        self.optimal_poly_z = None
        self.lock = Lock()
        # eval (accumulated stage timings, reported when verbose)
        self.time_forward = 0.0
        self.time_process = 0.0
        self.time_prepare = 0.0
        self.time_interpolation = 0.0
        self.time_visualize = 0.0
        self.count = 0
        # Load Network
        if self.use_trt:
            self.policy = TRTModule()
            self.policy.load_state_dict(torch.load(weight))
        else:
            saved_variables = torch.load(weight, map_location=self.device)
            saved_variables["data"]["lattice_space"] = self.lattice_space
            saved_variables["data"]["lattice_primitive"] = self.lattice_primitive
            self.policy = YopoPolicy(device=self.device, **saved_variables["data"])
            self.policy.load_state_dict(saved_variables["state_dict"], strict=False)
            self.policy.to(self.device)
            self.policy.set_training_mode(False)
        torch.set_grad_enabled(False)
        self.warm_up()
        # ros publisher
        odom_topic = self.config['odom_topic']
        depth_topic = self.config['depth_topic']
        self.lattice_traj_pub = rospy.Publisher("/yopo_net/lattice_trajs_visual", PointCloud2, queue_size=1)
        self.best_traj_pub = rospy.Publisher("/yopo_net/best_traj_visual", PointCloud2, queue_size=1)
        self.all_trajs_pub = rospy.Publisher("/yopo_net/trajs_visual", PointCloud2, queue_size=1)
        self.ctrl_pub = rospy.Publisher("/so3_control/pos_cmd", PositionCommand, queue_size=1)
        # ros subscriber
        self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.callback_odometry, queue_size=1)
        self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1)
        self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1)
        # ros timer
        rospy.sleep(1.0)  # wait connection...
        self.timer_ctrl = rospy.Timer(rospy.Duration(self.ctrl_dt), self.control_pub)
        print("YOPO Net Node Ready!")
        rospy.spin()
def callback_set_goal(self, data):
self.goal = np.asarray([data.pose.position.x, data.pose.position.y, 2])
self.arrive = False
print(f"New Goal: ({data.pose.position.x:.1f}, {data.pose.position.y:.1f})")
# the first frame
def callback_odometry(self, data):
self.odom = data
if not self.desire_init:
self.desire_pos = np.array((self.odom.pose.pose.position.x, self.odom.pose.pose.position.y, self.odom.pose.pose.position.z))
self.desire_vel = np.array((self.odom.twist.twist.linear.x, self.odom.twist.twist.linear.y, self.odom.twist.twist.linear.z))
self.desire_acc = np.array((0.0, 0.0, 0.0))
ypr = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y,
self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_euler('ZYX', degrees=False)
self.last_yaw = ypr[0]
self.odom_init = True
pos = np.array((self.odom.pose.pose.position.x, self.odom.pose.pose.position.y, self.odom.pose.pose.position.z))
if np.linalg.norm(pos - self.goal) < 4 and not self.arrive:
print("Arrive!")
self.arrive = True
def process_odom(self):
# Rwb -> Rwc -> Rcw
Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y,
self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_matrix()
self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc)
Rotation_cw = self.Rotation_wc.T
# vel and acc
vel_w = self.desire_vel
vel_c = np.dot(Rotation_cw, vel_w)
acc_w = self.desire_acc
acc_c = np.dot(Rotation_cw, acc_w)
# pose and goal_dir
goal_w = (self.goal - self.desire_pos) / np.linalg.norm(self.goal - self.desire_pos)
goal_c = np.dot(Rotation_cw, goal_w)
vel_acc = np.concatenate((vel_c, acc_c), axis=0)
vel_acc_norm = self.normalize_obs(vel_acc[np.newaxis, :])
obs_norm = np.hstack((vel_acc_norm, goal_c[np.newaxis, :]))
return obs_norm
def callback_depth(self, data):
if not self.odom_init:
return
# 1. Depth Image Process
min_dis, max_dis = 0.03, 20.0
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
try:
depth = self.bridge.imgmsg_to_cv2(data, "32FC1")
except:
print("CV_bridge ERROR: Possible solutions may be found at https://github.com/TJU-Aerial-Robotics/YOPO/issues/2")
time0 = time.time()
if depth.shape[0] != self.height or depth.shape[1] != self.width:
depth = cv2.resize(depth, (self.width, self.height), interpolation=cv2.INTER_NEAREST)
depth = np.minimum(depth * scale, max_dis) / max_dis
# interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results)
nan_mask = np.isnan(depth) | (depth < min_dis)
interpolated_image = cv2.inpaint(np.uint8(depth * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS)
interpolated_image = interpolated_image.astype(np.float32) / 255.0
depth = interpolated_image.reshape([1, 1, self.height, self.width])
# cv2.imshow("1", depth[0][0])
# cv2.waitKey(1)
# 2. YOPO Network Inference
# input prepare
time1 = time.time()
depth_input = torch.from_numpy(depth).to(self.device, non_blocking=True) # (non_blocking: copying speed 3x)
obs = self.process_odom()
obs_input = self.prepare_input_observation(obs)
obs_input = obs_input.to(self.device, non_blocking=True)
# torch.cuda.synchronize()
time2 = time.time()
# Forward (TensorRT: inference speed increased by 5x)
with torch.no_grad():
network_output = self.policy(depth_input, obs_input)
network_output = network_output.cpu().numpy() # torch.cuda.synchronize() is not needed here
time3 = time.time()
# Replacing PyTorch operation on CUDA with NumPy operation on CPU (speed increased by 10x)
endstate_pred, score_pred = self.process_output(network_output, return_all_preds=self.visualize)
# Vectorization: transform the prediction(P V A in body frame) to the world frame with the attitude (without the position)
endstate_c = endstate_pred.T.reshape(-1, 3, 3)
endstate_w = np.matmul(self.Rotation_wc, endstate_c)
# endstate_w = endstate_w.reshape(-1, 9).T
action_id = np.argmin(score_pred) if self.visualize else 0
with self.lock: # Python3.8: threads are scheduled using time slices, add the lock to ensure safety
self.optimal_poly_x = Poly5Solver(self.desire_pos[0], self.desire_vel[0], self.desire_acc[0],
endstate_w[action_id, 0, 0] + self.desire_pos[0], endstate_w[action_id, 0, 1], endstate_w[action_id, 0, 2], self.lattice_space.segment_time)
self.optimal_poly_y = Poly5Solver(self.desire_pos[1], self.desire_vel[1], self.desire_acc[1],
endstate_w[action_id, 1, 0] + self.desire_pos[1], endstate_w[action_id, 1, 1], endstate_w[action_id, 1, 2], self.lattice_space.segment_time)
self.optimal_poly_z = Poly5Solver(self.desire_pos[2], self.desire_vel[2], self.desire_acc[2],
endstate_w[action_id, 2, 0] + self.desire_pos[2], endstate_w[action_id, 2, 1], endstate_w[action_id, 2, 2], self.lattice_space.segment_time)
self.ctrl_time = 0.0
time4 = time.time()
self.visualize_trajectory(score_pred, endstate_w)
time5 = time.time()
if self.verbose:
self.time_interpolation = self.time_interpolation + (time1 - time0)
self.time_prepare = self.time_prepare + (time2 - time1)
self.time_forward = self.time_forward + (time3 - time2)
self.time_process = self.time_process + (time4 - time3)
self.time_visualize = self.time_visualize + (time5 - time4)
self.count = self.count + 1
print(f"Time Consuming:"
f"depth-interpolation: {1000 * self.time_interpolation / self.count:.2f}ms;"
f"data-prepare: {1000 * self.time_prepare / self.count:.2f}ms; "
f"network-inference: {1000 * self.time_forward / self.count:.2f}ms; "
f"post-process: {1000 * self.time_process / self.count:.2f}ms;"
f"visualize-trajectory: {1000 * self.time_visualize / self.count:.2f}ms")
    def control_pub(self, _timer):
        """Timer callback (every ctrl_dt): sample the current optimal quintic
        and publish a PositionCommand, advancing the desired reference state.

        Silently returns when no trajectory exists yet, the trajectory has been
        fully consumed, or the goal has been reached.
        """
        if self.ctrl_time is None or self.ctrl_time > self.lattice_space.segment_time:
            return
        if self.arrive:
            self.desire_init = False  # ready for next rollout
            return
        with self.lock:  # Python3.8: threads are scheduled using time slices, add the lock to ensure safety and publish frequency
            self.ctrl_time += self.ctrl_dt
            control_msg = PositionCommand()
            control_msg.header.stamp = rospy.Time.now()
            control_msg.trajectory_flag = control_msg.TRAJECTORY_STATUS_READY
            # Sample pos/vel/acc of each axis polynomial at the current time.
            control_msg.position.x = self.optimal_poly_x.get_position(self.ctrl_time)
            control_msg.position.y = self.optimal_poly_y.get_position(self.ctrl_time)
            control_msg.position.z = self.optimal_poly_z.get_position(self.ctrl_time)
            control_msg.velocity.x = self.optimal_poly_x.get_velocity(self.ctrl_time)
            control_msg.velocity.y = self.optimal_poly_y.get_velocity(self.ctrl_time)
            control_msg.velocity.z = self.optimal_poly_z.get_velocity(self.ctrl_time)
            control_msg.acceleration.x = self.optimal_poly_x.get_acceleration(self.ctrl_time)
            control_msg.acceleration.y = self.optimal_poly_y.get_acceleration(self.ctrl_time)
            control_msg.acceleration.z = self.optimal_poly_z.get_acceleration(self.ctrl_time)
            # The sampled command becomes the next planning anchor state.
            self.desire_pos = np.array([control_msg.position.x, control_msg.position.y, control_msg.position.z])
            self.desire_vel = np.array([control_msg.velocity.x, control_msg.velocity.y, control_msg.velocity.z])
            self.desire_acc = np.array([control_msg.acceleration.x, control_msg.acceleration.y, control_msg.acceleration.z])
            # Yaw blends velocity direction with the goal bearing, rate-limited.
            goal_dir = self.goal - self.desire_pos
            yaw, yaw_dot = calculate_yaw(self.desire_vel, goal_dir, self.last_yaw, self.ctrl_dt)
            self.last_yaw = yaw
            control_msg.yaw = yaw
            control_msg.yaw_dot = yaw_dot
            self.desire_init = True
            self.ctrl_pub.publish(control_msg)
def process_output(self, network_output, return_all_preds=False):
if network_output.shape[0] != 1:
raise ValueError("batch of output values must be 1 in test!")
network_output = network_output.reshape(10, self.lattice_space.horizon_num * self.lattice_space.vertical_num)
endstate_pred = network_output[0:9, :]
score_pred = network_output[9, :]
if not return_all_preds:
action_id = np.argmin(score_pred)
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id
endstate_prediction = self.pred_to_endstate(endstate_pred[:, action_id], lattice_id)
endstate_prediction = endstate_prediction[:, np.newaxis]
score_prediction = score_pred[action_id]
else:
endstate_prediction = np.zeros_like(endstate_pred)
score_prediction = score_pred
for i in range(self.lattice_space.horizon_num * self.lattice_space.vertical_num):
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i
endstate_prediction[:, i] = self.pred_to_endstate(endstate_pred[:, i], lattice_id)
return endstate_prediction, score_prediction
def prepare_input_observation(self, obs):
"""
convert the observation from body frame to primitive frame,
and then concatenate it with the depth features (to ensure the translational invariance)
"""
if obs.shape[0] != 1:
raise ValueError("batch of input observations must be 1 in test!")
obs_return = np.ones((obs.shape[0], obs.shape[1], self.lattice_space.vertical_num, self.lattice_space.horizon_num), dtype=np.float32)
id = 0
obs_reshaped = obs.reshape(3, 3)
for i in range(self.lattice_space.vertical_num - 1, -1, -1):
for j in range(self.lattice_space.horizon_num - 1, -1, -1):
Rbp = self.lattice_primitive.getRotation(id)
obs_return_reshaped = np.dot(obs_reshaped, Rbp)
obs_return[:, :, i, j] = obs_return_reshaped.reshape(9)
id = id + 1
return torch.from_numpy(obs_return)
def pred_to_endstate(self, endstate_pred: np.ndarray, id: int):
"""
Transform the predicted state to the body frame.
"""
delta_yaw = endstate_pred[0] * self.lattice_primitive.yaw_diff
delta_pitch = endstate_pred[1] * self.lattice_primitive.pitch_diff
radio = endstate_pred[2] * self.lattice_space.radio_range + self.lattice_space.radio_range
yaw, pitch = self.lattice_primitive.getAngleLattice(id)
endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio
endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio
endstate_z = np.sin(pitch + delta_pitch) * radio
endstate_p = np.array((endstate_x, endstate_y, endstate_z))
endstate_vp = endstate_pred[3:6] * self.lattice_space.vel_max
endstate_ap = endstate_pred[6:9] * self.lattice_space.acc_max
Rpb = self.lattice_primitive.getRotation(id).T
endstate_vb = np.matmul(endstate_vp, Rpb)
endstate_ab = np.matmul(endstate_ap, Rpb)
endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab))
endstate[[0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[[0, 3, 6, 1, 4, 7, 2, 5, 8]]
return endstate
def normalize_obs(self, vel_acc):
vel_norm = vel_acc[:, 0:3] / self.lattice_space.vel_max
acc_norm = vel_acc[:, 3:6] / self.lattice_space.acc_max
return np.hstack((vel_norm, acc_norm))
    def visualize_trajectory(self, pred_score, pred_endstate):
        """
        Publish point-cloud visualizations of the planner output for RViz.

        Three clouds are published, each only when it has subscribers:
          * best_traj_pub    -- the selected optimal trajectory (self.optimal_poly_*),
          * lattice_traj_pub -- the static lattice primitives (only if self.visualize),
          * all_trajs_pub    -- every predicted trajectory with its score stored in
                                the intensity channel (only if self.visualize).

        NOTE(review): pred_endstate is indexed below as [primitive, axis, derivative]
        (pos/vel/acc per x/y/z) -- confirm the layout against process_output.
        """
        # Sample 20 points along each trajectory segment.
        dt = self.lattice_space.segment_time / 20.0
        # best predicted trajectory
        if self.best_traj_pub.get_num_connections() > 0:
            t_values = np.arange(0, self.lattice_space.segment_time, dt)
            # (N, 3) array of xyz samples along the optimal polynomial.
            points_array = np.stack((
                self.optimal_poly_x.get_position(t_values),
                self.optimal_poly_y.get_position(t_values),
                self.optimal_poly_z.get_position(t_values)
            ), axis=-1)
            header = std_msgs.msg.Header()
            header.stamp = rospy.Time.now()
            header.frame_id = 'world'
            point_cloud_msg = point_cloud2.create_cloud_xyz32(header, points_array)
            self.best_traj_pub.publish(point_cloud_msg)
        # lattice primitive
        if self.visualize and self.lattice_traj_pub.get_num_connections() > 0:
            lattice_endstate = self.lattice_primitive.lattice_pos_node
            # Rotate lattice endpoints by Rotation_wc.T (row-vector convention;
            # presumably camera-to-world -- TODO confirm against the odom callback).
            lattice_endstate = np.dot(lattice_endstate, self.Rotation_wc.T)
            # Primitives terminate with zero velocity and acceleration.
            zero_state = np.zeros_like(lattice_endstate)
            lattice_poly_x = Polys5Solver(self.desire_pos[0], self.desire_vel[0], self.desire_acc[0],
                                          lattice_endstate[:, 0] + self.desire_pos[0], zero_state[:, 0], zero_state[:, 0], self.lattice_space.segment_time)
            lattice_poly_y = Polys5Solver(self.desire_pos[1], self.desire_vel[1], self.desire_acc[1],
                                          lattice_endstate[:, 1] + self.desire_pos[1], zero_state[:, 1], zero_state[:, 1], self.lattice_space.segment_time)
            lattice_poly_z = Polys5Solver(self.desire_pos[2], self.desire_vel[2], self.desire_acc[2],
                                          lattice_endstate[:, 2] + self.desire_pos[2], zero_state[:, 2], zero_state[:, 2], self.lattice_space.segment_time)
            t_values = np.arange(0, self.lattice_space.segment_time, dt)
            points_array = np.stack((
                lattice_poly_x.get_position(t_values),
                lattice_poly_y.get_position(t_values),
                lattice_poly_z.get_position(t_values)
            ), axis=-1)
            header = std_msgs.msg.Header()
            header.stamp = rospy.Time.now()
            header.frame_id = 'world'
            point_cloud_msg = point_cloud2.create_cloud_xyz32(header, points_array)
            self.lattice_traj_pub.publish(point_cloud_msg)
        # all predicted trajectories
        if self.visualize and self.all_trajs_pub.get_num_connections() > 0:
            # End positions are offsets from the desired position; velocity/acceleration are absolute.
            all_poly_x = Polys5Solver(self.desire_pos[0], self.desire_vel[0], self.desire_acc[0],
                                      pred_endstate[:, 0, 0] + self.desire_pos[0], pred_endstate[:, 0, 1], pred_endstate[:, 0, 2], self.lattice_space.segment_time)
            all_poly_y = Polys5Solver(self.desire_pos[1], self.desire_vel[1], self.desire_acc[1],
                                      pred_endstate[:, 1, 0] + self.desire_pos[1], pred_endstate[:, 1, 1], pred_endstate[:, 1, 2], self.lattice_space.segment_time)
            all_poly_z = Polys5Solver(self.desire_pos[2], self.desire_vel[2], self.desire_acc[2],
                                      pred_endstate[:, 2, 0] + self.desire_pos[2], pred_endstate[:, 2, 1], pred_endstate[:, 2, 2], self.lattice_space.segment_time)
            t_values = np.arange(0, self.lattice_space.segment_time, dt)
            points_array = np.stack((
                all_poly_x.get_position(t_values),
                all_poly_y.get_position(t_values),
                all_poly_z.get_position(t_values)
            ), axis=-1)
            # Attach each trajectory's score to all of its sampled points as intensity.
            scores = np.repeat(pred_score, t_values.size)
            points_array = np.column_stack((points_array, scores))
            header = std_msgs.msg.Header()
            header.stamp = rospy.Time.now()
            header.frame_id = 'world'
            fields = [PointField('x', 0, PointField.FLOAT32, 1), PointField('y', 4, PointField.FLOAT32, 1),
                      PointField('z', 8, PointField.FLOAT32, 1), PointField('intensity', 12, PointField.FLOAT32, 1)]
            point_cloud_msg = point_cloud2.create_cloud(header, fields, points_array)
            self.all_trajs_pub.publish(point_cloud_msg)
def warm_up(self):
depth = np.zeros(shape=[1, 1, self.height, self.width], dtype=np.float32)
obs = np.zeros(shape=[1, 9], dtype=np.float32)
obs_input = self.prepare_input_observation(obs)
network_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device))
self.process_output(network_output.cpu().numpy(), return_all_preds=True)
def parser():
    """Build the command-line argument parser for the YOPO test node.

    Returns an argparse.ArgumentParser with trial/epoch/iter selectors for the
    trained weights and the TensorRT options.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--use_tensorrt", type=int, default=0, help="use tensorrt or not")
    arg_parser.add_argument("--trial", type=int, default=1, help="trial number")
    arg_parser.add_argument("--epoch", type=int, default=0, help="epoch number")
    arg_parser.add_argument("--iter", type=int, default=0, help="iter number")
    arg_parser.add_argument("--trt_file", type=str, default='yopo_trt.pth', help="tensorrt filename")
    return arg_parser
# In real-world flight: set visualize=False and use_tensorrt=True, and ensure the pitch_angle is consistent with your platform
# When modifying the pitch_angle, there is no need to re-collect data or re-train, as all predictions are in the camera coordinate system
# Change the flight speed in traj_opt.yaml; there is likewise no need to re-collect data or re-train
def main():
    """Entry point: resolve the network weight path from CLI args and start the YOPO node."""
    args = parser().parse_args()
    rsg_root = os.path.dirname(os.path.abspath(__file__))
    # TensorRT engines are loaded from an explicit file; PyTorch weights follow the
    # saved/YOPO_<trial>/Policy/epoch<e>_iter<i>.pth training layout.
    if args.use_tensorrt:
        weight = args.trt_file
    else:
        weight = f"{rsg_root}/saved/YOPO_{args.trial}/Policy/epoch{args.epoch}_iter{args.iter}.pth"
    print("load weight from:", weight)
    settings = {
        'use_tensorrt': args.use_tensorrt,
        'img_height': 96,
        'img_width': 160,
        'goal': [20, 20, 2],  # the goal
        'env': 'flightmare',  # use Realsense D435 or Flightmare Simulator ('435' or 'flightmare')
        'pitch_angle_deg': -5,  # pitch of camera, ensure consistent with the simulator or your platform (no need to re-collect and re-train when modifying)
        'odom_topic': '/juliett/ground_truth/odom',
        'depth_topic': '/depth_image',
        'verbose': False,  # print the latency?
        'visualize': True,  # visualize all predictions? set False in real flight
    }
    # YopoNet registers its ROS subscribers/publishers in the constructor and runs from callbacks.
    YopoNet(settings, weight)
# Script entry point: only run the node when executed directly, not when imported.
if __name__ == "__main__":
    main()