modify dataloader and update readme

This commit is contained in:
TJU-Lu 2025-10-15 23:37:38 +08:00
parent c445f0727f
commit cd0008c94e
4 changed files with 11 additions and 8 deletions

View File

@ -162,7 +162,7 @@ conda activate yopo
tensorboard --logdir=./ tensorboard --logdir=./
``` ```
<p align="center"> <p align="center">
<img src="docs/train_log.png" alt="train_log" width="80%"/> <img src="docs/train_log.png" alt="train_log" width="100%"/>
</p> </p>
Besides, you can refer to [traj_opt.yaml](YOPO/config/traj_opt.yaml) for modifications of trajectory optimization (e.g. the speed and penalties). Besides, you can refer to [traj_opt.yaml](YOPO/config/traj_opt.yaml) for modifications of trajectory optimization (e.g. the speed and penalties).

View File

@ -26,6 +26,8 @@ class GuidanceLoss(nn.Module):
traj_dir = end_pos - cur_pos # [B, 3] traj_dir = end_pos - cur_pos # [B, 3]
goal_dir = goal - cur_pos # [B, 3] goal_dir = goal - cur_pos # [B, 3]
# NOTE: distance_loss performs better in general tasks, while we choose terminal_aware_similarity_loss only for higher speed in large-scale scenario.
# guidance_loss = self.distance_loss(traj_dir, goal_dir)
guidance_loss = self.terminal_aware_similarity_loss(traj_dir, goal_dir) guidance_loss = self.terminal_aware_similarity_loss(traj_dir, goal_dir)
return guidance_loss return guidance_loss

View File

@ -31,17 +31,18 @@ class YOPODataset(Dataset):
if mode == 'train': self.print_data() if mode == 'train': self.print_data()
# dataset # dataset
print("Loading", mode, "dataset, it may take a while...")
base_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"]) data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32) self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)
datafolders = [f.path for f in os.scandir(data_dir) if f.is_dir()] datafolders = [f.path for f in os.scandir(data_dir) if f.is_dir()]
datafolders.sort(key=lambda x: int(os.path.basename(x))) datafolders.sort(key=lambda x: int(os.path.basename(x)))
if mode == 'train':
print("Datafolders:") print("Datafolders:")
for folder in datafolders: for folder in datafolders:
print(" ", folder) print(" ", folder)
print("Loading", mode, "dataset")
for data_idx in range(len(datafolders)): for data_idx in range(len(datafolders)):
datafolder = datafolders[data_idx] datafolder = datafolders[data_idx]
@ -75,7 +76,6 @@ class YOPODataset(Dataset):
print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}") print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}")
print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}") print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}")
print("==================================================") print("==================================================")
print(mode.capitalize(), "data loaded!")
def __len__(self): def __len__(self):
return len(self.img_list) return len(self.img_list)
@ -210,10 +210,11 @@ class YOPODataset(Dataset):
if __name__ == '__main__': if __name__ == '__main__':
# plot the random sample
dataset = YOPODataset() dataset = YOPODataset()
# dataset.plot_sample_distribution() dataset.plot_sample_distribution()
dataset = YOPODataset() # select the best num_workers
max_workers = os.cpu_count() max_workers = os.cpu_count()
print(f"\n✅ cpu_count = {max_workers}") print(f"\n✅ cpu_count = {max_workers}")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 69 KiB

After

Width:  |  Height:  |  Size: 133 KiB