diff --git a/YOPO/policy/yopo_dataset.py b/YOPO/policy/yopo_dataset.py index 6fa16ff..6ececd5 100644 --- a/YOPO/policy/yopo_dataset.py +++ b/YOPO/policy/yopo_dataset.py @@ -1,6 +1,7 @@ import os, sys import cv2 import time +import torch import numpy as np from torch.utils.data import Dataset, DataLoader from scipy.spatial.transform import Rotation as R @@ -44,10 +45,10 @@ class YOPODataset(Dataset): for data_idx in range(len(datafolders)): datafolder = datafolders[data_idx] - image_file_names = [filename + image_file_names = [datafolder + "/" + filename for filename in os.listdir(datafolder) if os.path.splitext(filename)[1] == '.png'] - image_file_names.sort(key=lambda x: int(x.split('.')[0].split("_")[1])) # sort by filename to align with the label + image_file_names.sort(key=lambda x: int(os.path.basename(x).split('.')[0].split("_")[1])) # sort by filename to align with the label states = np.loadtxt(data_dir + f"/pose-{data_idx}.csv", delimiter=',', skiprows=1).astype(np.float32) positions = states[:, 0:3] @@ -57,28 +58,20 @@ class YOPODataset(Dataset): image_file_names, positions, quaternions, test_size=val_ratio, random_state=0) if mode == 'train': - images = [cv2.imread(datafolder + "/" + filename, -1).astype(np.float32) for filename in file_names_train] - self.img_list.extend(images) + self.img_list.extend(file_names_train) self.positions = np.vstack((self.positions, positions_train.astype(np.float32))) self.quaternions = np.vstack((self.quaternions, quaternions_train.astype(np.float32))) + self.map_idx.extend([data_idx] * len(file_names_train)) elif mode == 'valid': - images = [cv2.imread(datafolder + "/" + filename, -1).astype(np.float32) for filename in file_names_val] - self.img_list.extend(images) + self.img_list.extend(file_names_val) self.positions = np.vstack((self.positions, positions_val.astype(np.float32))) self.quaternions = np.vstack((self.quaternions, quaternions_val.astype(np.float32))) + self.map_idx.extend([data_idx] * len(file_names_val)) 
else: raise ValueError(f"Invalid mode {mode}. Choose from 'train', 'valid'.") - self.map_idx.extend([data_idx] * len(images)) - - # NOTE: The depth images are normalized from 0–20m to a 0–1 and converted to int16 during data collection. - self.img_list = [np.expand_dims( - cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_NEAREST) / 65535.0, - axis=0) - for img in self.img_list] - print(f"=============== {mode.capitalize()} Data Summary ===============") - print(f"{'Images' :<12} | Count: {len(self.img_list):<3} | Shape: {self.img_list[0].shape}") + print(f"{'Images' :<12} | Count: {len(self.img_list):<3} | Shape: {self.width},{self.height}") print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}") print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}") print("==================================================") @@ -88,9 +81,15 @@ class YOPODataset(Dataset): return len(self.img_list) def __getitem__(self, item): + # 1. read the image + # NOTE: During data collection the depth images were normalized from 0–20 m to 0–1 and stored as 16-bit integers (presumably uint16, given the 65535.0 divisor; confirm). + image = cv2.imread(self.img_list[item], -1).astype(np.float32) + image = np.expand_dims(cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_NEAREST) / 65535.0, axis=0) + + # 2. get random vel, acc vel_b, acc_b = self._get_random_state() - # generate random goal in front of the quadrotor. + # 3. generate random goal in front of the quadrotor. 
q_wxyz = self.quaternions[item, :] # q: wxyz R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]]) euler_angles = R_WB.as_euler('ZYX', degrees=False) # [yaw(z) pitch(y) roll(x)] @@ -101,7 +100,7 @@ class YOPODataset(Dataset): random_obs = np.hstack((vel_b, acc_b, goal_b)).astype(np.float32) rot_wb = R_WB.as_matrix().astype(np.float32) # transform to rot_matrix in numpy is faster than using quat in pytorch # vel & acc & goal are in body frame, NWU, and no-normalization - return self.img_list[item], self.positions[item], rot_wb, random_obs, self.map_idx[item] + return image, self.positions[item], rot_wb, random_obs, self.map_idx[item] def _get_random_state(self): while True: @@ -212,14 +211,23 @@ class YOPODataset(Dataset): if __name__ == '__main__': dataset = YOPODataset() - dataset.plot_sample_distribution() - data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) + # dataset.plot_sample_distribution() - start = time.time() - for epoch in range(1): - last = time.time() - for i, (depth, pos, quat, obs, id) in enumerate(data_loader): - pass - end = time.time() + dataset = YOPODataset() + max_workers = os.cpu_count() + print(f"\n✅ cpu_count = {max_workers}") - print("加载1个epoch总耗时:", end - start) + results = [] + for nw in range(0, max_workers + 1): + data_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=nw) + start = time.time() + for i, _ in enumerate(data_loader): + if i > 50: # only time the first 50 batches + break + torch.cuda.synchronize() if torch.cuda.is_available() else None + elapsed = time.time() - start + results.append((nw, elapsed)) + print(f"num_workers={nw}: {elapsed:.3f}s") + + best = min(results, key=lambda x: x[1]) + print(f"\n✅ 最优 num_workers = {best[0]}, 平均耗时={best[1]:.3f}s") diff --git a/YOPO/policy/yopo_trainer.py b/YOPO/policy/yopo_trainer.py index 626cbcc..2bb5f59 100644 --- a/YOPO/policy/yopo_trainer.py +++ b/YOPO/policy/yopo_trainer.py @@ -57,11 +57,11 @@ class YopoTrainer: self.optimizer = 
torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True) print("Network Loaded! Loading Dataset...") - # dataset + # dataset (you can adjust num_workers according to your training speed) self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, - num_workers=1, pin_memory=True) + num_workers=4, pin_memory=True) self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, - num_workers=1, pin_memory=True) + num_workers=4, pin_memory=True) print("Dataset Loaded!") def train(self, epoch, save_interval=None):