modify dataloader and update readme
This commit is contained in:
parent
c445f0727f
commit
cd0008c94e
@ -162,7 +162,7 @@ conda activate yopo
|
|||||||
tensorboard --logdir=./
|
tensorboard --logdir=./
|
||||||
```
|
```
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="docs/train_log.png" alt="train_log" width="80%"/>
|
<img src="docs/train_log.png" alt="train_log" width="100%"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Besides, you can refer to [traj_opt.yaml](YOPO/config/traj_opt.yaml) for modifications of trajectory optimization (e.g. the speed and penalties).
|
Besides, you can refer to [traj_opt.yaml](YOPO/config/traj_opt.yaml) for modifications of trajectory optimization (e.g. the speed and penalties).
|
||||||
|
|||||||
@ -26,6 +26,8 @@ class GuidanceLoss(nn.Module):
|
|||||||
traj_dir = end_pos - cur_pos # [B, 3]
|
traj_dir = end_pos - cur_pos # [B, 3]
|
||||||
goal_dir = goal - cur_pos # [B, 3]
|
goal_dir = goal - cur_pos # [B, 3]
|
||||||
|
|
||||||
|
# NOTE: distance_loss performs better in general tasks, while we choose terminal_aware_similarity_loss only for higher speed in large-scale scenario.
|
||||||
|
# guidance_loss = self.distance_loss(traj_dir, goal_dir)
|
||||||
guidance_loss = self.terminal_aware_similarity_loss(traj_dir, goal_dir)
|
guidance_loss = self.terminal_aware_similarity_loss(traj_dir, goal_dir)
|
||||||
return guidance_loss
|
return guidance_loss
|
||||||
|
|
||||||
|
|||||||
@ -31,17 +31,18 @@ class YOPODataset(Dataset):
|
|||||||
if mode == 'train': self.print_data()
|
if mode == 'train': self.print_data()
|
||||||
|
|
||||||
# dataset
|
# dataset
|
||||||
print("Loading", mode, "dataset, it may take a while...")
|
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
|
||||||
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)
|
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)
|
||||||
|
|
||||||
datafolders = [f.path for f in os.scandir(data_dir) if f.is_dir()]
|
datafolders = [f.path for f in os.scandir(data_dir) if f.is_dir()]
|
||||||
datafolders.sort(key=lambda x: int(os.path.basename(x)))
|
datafolders.sort(key=lambda x: int(os.path.basename(x)))
|
||||||
print("Datafolders:")
|
if mode == 'train':
|
||||||
for folder in datafolders:
|
print("Datafolders:")
|
||||||
print(" ", folder)
|
for folder in datafolders:
|
||||||
|
print(" ", folder)
|
||||||
|
|
||||||
|
print("Loading", mode, "dataset")
|
||||||
for data_idx in range(len(datafolders)):
|
for data_idx in range(len(datafolders)):
|
||||||
datafolder = datafolders[data_idx]
|
datafolder = datafolders[data_idx]
|
||||||
|
|
||||||
@ -75,7 +76,6 @@ class YOPODataset(Dataset):
|
|||||||
print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}")
|
print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}")
|
||||||
print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}")
|
print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}")
|
||||||
print("==================================================")
|
print("==================================================")
|
||||||
print(mode.capitalize(), "data loaded!")
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.img_list)
|
return len(self.img_list)
|
||||||
@ -210,10 +210,11 @@ class YOPODataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# plot the random sample
|
||||||
dataset = YOPODataset()
|
dataset = YOPODataset()
|
||||||
# dataset.plot_sample_distribution()
|
dataset.plot_sample_distribution()
|
||||||
|
|
||||||
dataset = YOPODataset()
|
# select the best num_workers
|
||||||
max_workers = os.cpu_count()
|
max_workers = os.cpu_count()
|
||||||
print(f"\n✅ cpu_count = {max_workers}")
|
print(f"\n✅ cpu_count = {max_workers}")
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 69 KiB After Width: | Height: | Size: 133 KiB |
Loading…
x
Reference in New Issue
Block a user