107 lines
4.9 KiB
Python
107 lines
4.9 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from ruamel.yaml import YAML
|
|
import time
|
|
from scipy.spatial.transform import Rotation as R
|
|
|
|
|
|
class YopoDataset(Dataset):
|
|
def __init__(self):
|
|
super(YopoDataset, self).__init__()
|
|
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml", 'r'))
|
|
scale = 32 # 神经网络下采样倍数
|
|
self.height = scale * cfg["vertical_num"]
|
|
self.width = scale * cfg["horizon_num"]
|
|
multiple_ = 0.5 * cfg["vel_max"]
|
|
# The x-direction follows a log-normal distribution,
|
|
# while the yz-direction follows a normal distribution with a mean of 0.
|
|
self.v_max = cfg["vel_max"]
|
|
v_des = multiple_ * cfg["vx_mean_unit"]
|
|
self.vx_lognorm_mean = np.log(self.v_max - v_des)
|
|
self.vx_logmorm_sigma = np.log(np.sqrt(v_des))
|
|
self.v_mean = multiple_ * np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
|
|
self.v_var = multiple_ * multiple_ * np.array([cfg["vx_var_unit"], cfg["vy_var_unit"], cfg["vz_var_unit"]])
|
|
self.a_mean = multiple_ * multiple_ * np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
|
|
self.a_var = multiple_ * multiple_ * multiple_ * multiple_ * np.array([cfg["ax_var_unit"], cfg["ay_var_unit"], cfg["az_var_unit"]])
|
|
|
|
print("Loading dataset, it may take a while...")
|
|
data_cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
|
data_dir = os.environ["FLIGHTMARE_PATH"] + data_cfg["env"]["dataset_path"]
|
|
|
|
self.img_list = []
|
|
self.map_idx = []
|
|
self.positions = np.empty((0, 3))
|
|
self.quaternions = np.empty((0, 4))
|
|
subfolders = [f.path for f in os.scandir(data_dir) if f.is_dir()]
|
|
subfolders.sort(key=lambda x: os.path.basename(x).lower())
|
|
for i in range(len(subfolders)):
|
|
img_dir = subfolders[i]
|
|
file_names = [filename
|
|
for filename in os.listdir(img_dir)
|
|
if os.path.splitext(filename)[1] == '.tif']
|
|
file_names.sort(key=lambda x: int(x.split('.')[0].split("_")[1])) # sort by filename
|
|
images = [cv2.imread(img_dir + "/" + filename, -1).astype(np.float32) for filename in file_names]
|
|
self.img_list.extend(images)
|
|
self.map_idx.extend([i] * len(images))
|
|
|
|
label_path = img_dir + "/label.npz"
|
|
labels = np.load(label_path)
|
|
self.positions = np.vstack((self.positions, labels["positions"]))
|
|
self.quaternions = np.vstack((self.quaternions, labels["quaternions"]))
|
|
|
|
print("Dataset loaded!")
|
|
|
|
def __len__(self):
|
|
return len(self.img_list)
|
|
|
|
def __getitem__(self, item):
|
|
if self.img_list[item].shape[-2] != self.height or self.img_list[item].shape[-1] != self.width:
|
|
self.img_list[item] = cv2.resize(self.img_list[item], (self.width, self.height)) # OpenCV and NumPy is Dif
|
|
|
|
if len(self.img_list[item].shape) == 2:
|
|
self.img_list[item] = np.expand_dims(self.img_list[item], axis=0)
|
|
|
|
vel, acc = self._get_random_state()
|
|
|
|
# generate random goal in front of the quadrotor.
|
|
q_wxyz = self.quaternions[item, :] # q: wxyz
|
|
R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]])
|
|
euler_angles = R_WB.as_euler('ZYX', degrees=False) # [yaw(z) pitch(y) roll(x)]
|
|
R_wB = R.from_euler('ZYX', [0, euler_angles[1], euler_angles[2]], degrees=False)
|
|
goal_w = np.random.randn(3) + np.array([2, 0, 0])
|
|
goal_b = R_wB.inv().apply(goal_w)
|
|
|
|
goal_dist = np.linalg.norm(goal_b)
|
|
goal_dir = goal_b / goal_dist
|
|
random_obs = np.hstack((vel, acc, goal_dir))
|
|
|
|
return (self.img_list[item], self.positions[item, :], self.quaternions[item, :], random_obs,
|
|
self.map_idx[item]) # in body frame, vel_acc no-normalization
|
|
|
|
def _get_random_state(self):
|
|
vel = self.v_mean + np.sqrt(self.v_var) * np.random.randn(3)
|
|
acc = self.a_mean + np.sqrt(self.a_var) * np.random.randn(3)
|
|
|
|
right_skewed_vx = -1
|
|
while right_skewed_vx < 0:
|
|
right_skewed_vx = np.random.lognormal(mean=self.vx_lognorm_mean, sigma=self.vx_logmorm_sigma, size=None)
|
|
right_skewed_vx = -right_skewed_vx + self.v_max + 0.2 # +0.2 to ensure v_max can be sampled
|
|
vel[0] = right_skewed_vx
|
|
# distribution of vx is visualized in docs/distribution_of_sampled_velocity.png (v_max=6)
|
|
return vel, acc
|
|
|
|
|
|
if __name__ == '__main__':
|
|
data_loader = DataLoader(YopoDataset(), batch_size=32, shuffle=True, num_workers=4)
|
|
|
|
start = time.time()
|
|
for epoch in range(1):
|
|
last = time.time()
|
|
for i, (depth, pos, quat, obs, id) in enumerate(data_loader):
|
|
pass
|
|
end = time.time()
|
|
|
|
print("总耗时:", end - start)
|