diff --git a/README.md b/README.md index 45b39dd..78e8007 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ conda activate yopo tensorboard --logdir=./ ```

- train_log + train_log

Besides, you can refer to [traj_opt.yaml](YOPO/config/traj_opt.yaml) for modifications of trajectory optimization (e.g. the speed and penalties). diff --git a/YOPO/loss/guidance_loss.py b/YOPO/loss/guidance_loss.py index a4646a3..66e111b 100644 --- a/YOPO/loss/guidance_loss.py +++ b/YOPO/loss/guidance_loss.py @@ -26,6 +26,8 @@ class GuidanceLoss(nn.Module): traj_dir = end_pos - cur_pos # [B, 3] goal_dir = goal - cur_pos # [B, 3] + # NOTE: distance_loss performs better in general tasks, while we choose terminal_aware_similarity_loss only for higher speed in large-scale scenario. + # guidance_loss = self.distance_loss(traj_dir, goal_dir) guidance_loss = self.terminal_aware_similarity_loss(traj_dir, goal_dir) return guidance_loss diff --git a/YOPO/policy/yopo_dataset.py b/YOPO/policy/yopo_dataset.py index 6ececd5..2273e01 100644 --- a/YOPO/policy/yopo_dataset.py +++ b/YOPO/policy/yopo_dataset.py @@ -31,17 +31,18 @@ class YOPODataset(Dataset): if mode == 'train': self.print_data() # dataset - print("Loading", mode, "dataset, it may take a while...") base_dir = os.path.dirname(os.path.abspath(__file__)) data_dir = os.path.join(base_dir, "../", cfg["dataset_path"]) self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32) datafolders = [f.path for f in os.scandir(data_dir) if f.is_dir()] datafolders.sort(key=lambda x: int(os.path.basename(x))) - print("Datafolders:") - for folder in datafolders: - print(" ", folder) + if mode == 'train': + print("Datafolders:") + for folder in datafolders: + print(" ", folder) + print("Loading", mode, "dataset") for data_idx in range(len(datafolders)): datafolder = datafolders[data_idx] @@ -75,7 +76,6 @@ class YOPODataset(Dataset): print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}") print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}") print("==================================================") - print(mode.capitalize(), "data loaded!") def __len__(self): return len(self.img_list) @@ -210,10 +210,11 @@ class YOPODataset(Dataset): if __name__ == '__main__': + # plot the random sample dataset = YOPODataset() - # dataset.plot_sample_distribution() + dataset.plot_sample_distribution() - dataset = YOPODataset() + # select the best num_workers max_workers = os.cpu_count() print(f"\n✅ cpu_count = {max_workers}") diff --git a/docs/train_log.png b/docs/train_log.png index 7afb83e..96a87b5 100644 Binary files a/docs/train_log.png and b/docs/train_log.png differ