# YOPO/flightpolicy/yopo/dataloader.py
#
# 109 lines
# 4.9 KiB
# Python

import os
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
from ruamel.yaml import YAML
import time
from scipy.spatial.transform import Rotation as R
class YopoDataset(Dataset):
    """In-memory dataset of depth images with pose labels and random states.

    Every ``.tif`` image found in the numbered sub-folders of the configured
    dataset directory is read into memory at construction time, together with
    the ``positions`` and ``quaternions`` arrays stored in each folder's
    ``label.npz``. Each sample additionally carries a freshly sampled random
    velocity, acceleration and goal direction.
    """

    def __init__(self):
        """Read the sampling configuration and load the whole dataset.

        Raises:
            KeyError: if the ``FLIGHTMARE_PATH`` environment variable is unset.
            IOError: if an image file cannot be decoded.
        """
        super().__init__()
        flightmare_root = os.environ["FLIGHTMARE_PATH"]

        # Close the config file deterministically instead of leaking the handle.
        with open(flightmare_root + "/flightlib/configs/traj_opt.yaml", 'r') as cfg_file:
            cfg = YAML().load(cfg_file)

        scale = 32  # down-sampling factor of the neural network
        self.height = scale * cfg["vertical_num"]
        self.width = scale * cfg["horizon_num"]
        multiple_ = 0.5 * cfg["vel_max"]

        # The x-direction follows a log-normal distribution,
        # while the yz-direction follows a normal distribution with a mean of 0.
        self.v_max = cfg["vel_max"]
        v_des = multiple_ * cfg["vx_mean_unit"]
        self.vx_lognorm_mean = np.log(self.v_max - v_des)
        # NOTE: the attribute name is misspelled ("logmorm"); it is kept as-is
        # because it is publicly visible on the instance.
        # NOTE(review): np.random.lognormal rejects a negative sigma, which
        # happens here whenever v_des < 1 - confirm the config guarantees that.
        self.vx_logmorm_sigma = np.log(np.sqrt(v_des))

        # Unit statistics from the config scaled to the configured max speed:
        # means scale with multiple_ (velocity) / multiple_^2 (acceleration),
        # variances with the square of those factors.
        self.v_mean = multiple_ * np.array(
            [cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
        self.v_var = multiple_ ** 2 * np.array(
            [cfg["vx_var_unit"], cfg["vy_var_unit"], cfg["vz_var_unit"]])
        self.a_mean = multiple_ ** 2 * np.array(
            [cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
        self.a_var = multiple_ ** 4 * np.array(
            [cfg["ax_var_unit"], cfg["ay_var_unit"], cfg["az_var_unit"]])

        print("Loading dataset, it may take a while...")
        with open(flightmare_root + "/flightlib/configs/vec_env.yaml", 'r') as data_cfg_file:
            data_cfg = YAML().load(data_cfg_file)
        data_dir = flightmare_root + data_cfg["env"]["dataset_path"]

        self.img_list = []
        self.map_idx = []
        # Label chunks are collected and stacked once after the loop; stacking
        # inside the loop would re-copy everything loaded so far per folder.
        # The empty float64 seeds keep the result dtype/shape identical to
        # incremental stacking (including the "no folders" case).
        position_chunks = [np.empty((0, 3))]
        quaternion_chunks = [np.empty((0, 4))]

        # Sub-folders are named by integers; sort numerically, not lexically.
        subfolders = [entry.path for entry in os.scandir(data_dir) if entry.is_dir()]
        subfolders.sort(key=lambda path: int(os.path.basename(path)))
        for folder in subfolders:
            print(folder)

        for map_id, img_dir in enumerate(subfolders):
            file_names = [name for name in os.listdir(img_dir)
                          if os.path.splitext(name)[1] == '.tif']
            # File names look like "<prefix>_<index>.tif"; sort by <index>.
            file_names.sort(key=lambda name: int(name.split('.')[0].split("_")[1]))
            images = [self._read_image(img_dir + "/" + name) for name in file_names]
            self.img_list.extend(images)
            self.map_idx.extend([map_id] * len(images))

            # NpzFile keeps the archive open until closed; read inside `with`.
            with np.load(img_dir + "/label.npz") as labels:
                position_chunks.append(labels["positions"])
                quaternion_chunks.append(labels["quaternions"])

        self.positions = np.vstack(position_chunks)
        self.quaternions = np.vstack(quaternion_chunks)
        print("Dataset loaded!")

    @staticmethod
    def _read_image(path):
        """Read one image unchanged (flag -1) and return it as float32.

        Raises:
            IOError: if OpenCV cannot read the file (``cv2.imread`` returns
                ``None`` instead of raising).
        """
        image = cv2.imread(path, -1)
        if image is None:
            raise IOError(f"Failed to read image: {path}")
        return image.astype(np.float32)

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.img_list)

    def __getitem__(self, item):
        """Return one sample.

        The cached image is resized to ``(height, width)`` and given a leading
        channel axis on first access; the result is written back to
        ``img_list`` so later accesses skip that work.

        Returns:
            A tuple ``(image, position, quaternion, random_obs, map_idx)``
            where ``random_obs`` is the concatenation of a random velocity,
            a random acceleration and a unit goal direction (9 values).
        """
        image = self.img_list[item]
        if image.shape[-2] != self.height or image.shape[-1] != self.width:
            # cv2.resize takes (width, height), the reverse of NumPy's order.
            image = cv2.resize(image, (self.width, self.height))
            self.img_list[item] = image
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=0)
            self.img_list[item] = image

        vel, acc = self._get_random_state()

        # Generate a random goal in front of the quadrotor.
        q_wxyz = self.quaternions[item, :]  # stored as wxyz
        # scipy expects xyzw ordering.
        rot_full = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]])
        euler_angles = rot_full.as_euler('ZYX', degrees=False)  # [yaw(z) pitch(y) roll(x)]
        # Same attitude with the yaw zeroed: only pitch and roll remain.
        rot_no_yaw = R.from_euler('ZYX', [0, euler_angles[1], euler_angles[2]], degrees=False)
        goal_w = np.random.randn(3) + np.array([2, 0, 0])
        goal_b = rot_no_yaw.inv().apply(goal_w)
        goal_dir = goal_b / np.linalg.norm(goal_b)

        random_obs = np.hstack((vel, acc, goal_dir))
        # in body frame, vel/acc are not normalized
        return (self.img_list[item], self.positions[item, :], self.quaternions[item, :],
                random_obs, self.map_idx[item])

    def _get_random_state(self):
        """Sample a random velocity and acceleration.

        All components are first drawn from normal distributions given by the
        configured means/variances; the x velocity is then replaced by a
        mirrored log-normal sample that is re-drawn until it is non-negative.

        Returns:
            ``(vel, acc)``, two arrays of three values each.
        """
        vel = self.v_mean + np.sqrt(self.v_var) * np.random.randn(3)
        acc = self.a_mean + np.sqrt(self.a_var) * np.random.randn(3)

        right_skewed_vx = -1
        while right_skewed_vx < 0:
            sample = np.random.lognormal(mean=self.vx_lognorm_mean,
                                         sigma=self.vx_logmorm_sigma, size=None)
            # Mirror the log-normal sample; +0.2 to ensure v_max can be sampled.
            right_skewed_vx = -sample + self.v_max + 0.2
        vel[0] = right_skewed_vx
        # distribution of vx is visualized in docs/distribution_of_sampled_velocity.png (v_max=6)
        return vel, acc
if __name__ == '__main__':
    # Smoke test: build the dataset, iterate over it once with a multi-worker
    # DataLoader, and report the total wall-clock time.
    data_loader = DataLoader(YopoDataset(), batch_size=32, shuffle=True, num_workers=4)
    start = time.time()
    for epoch in range(1):
        # Batches are only pulled to exercise loading; nothing is done with them.
        for depth, pos, quat, obs, map_id in data_loader:
            pass
    end = time.time()
    # The printed label means "total time elapsed".
    print("总耗时:", end - start)