From b7bfea3a5c3ff2f97dfe14d338c093e7e9d442c8 Mon Sep 17 00:00:00 2001 From: TJU_Lu Date: Tue, 8 Jul 2025 22:14:56 +0800 Subject: [PATCH] fine-tuning params --- README.md | 2 +- YOPO/config/traj_opt.yaml | 2 +- YOPO/loss/safety_loss.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index bfa5ea9..f851797 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ cd YOPO/ conda activate yopo python train_yopo.py ``` -It takes less than 1 hour to train on 100,000 samples for 50 epochs on an RTX 3080 GPU. Besides, we highly recommend binding the process to P-cores via `taskset -c 1,2,3,4 python train_yopo.py` if your CPU uses a hybrid architecture with P-cores and E-cores. If everything goes well, the training log is as follows: +It takes less than 1 hour to train on 100,000 samples for 50 epochs on an RTX 3080 GPU and i9-12900K CPU. Besides, we highly recommend binding the process to P-cores via `taskset -c 1,2,3,4 python train_yopo.py` if your CPU uses a hybrid architecture with P-cores and E-cores. If everything goes well, the training log is as follows: ``` cd YOPO/saved diff --git a/YOPO/config/traj_opt.yaml b/YOPO/config/traj_opt.yaml index 4994aae..7be5200 100644 --- a/YOPO/config/traj_opt.yaml +++ b/YOPO/config/traj_opt.yaml @@ -7,7 +7,7 @@ vel_max_train: 6.0 acc_max_train: 6.0 # IMPORTANT: weight of costs for unit speed (can be visualized in tensorboard) -wg: 0.1 # guidance +wg: 0.12 # guidance ws: 10.0 # smoothness wc: 0.1 # collision diff --git a/YOPO/loss/safety_loss.py b/YOPO/loss/safety_loss.py index e1310bc..cbe5f4f 100644 --- a/YOPO/loss/safety_loss.py +++ b/YOPO/loss/safety_loss.py @@ -57,10 +57,10 @@ class SafetyLoss(nn.Module): # get info from sdf_map cost = self.get_distance_cost(pos_batch, map_id) - cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N, 3] - cost_colli = cost_dt.sum(dim=-1, keepdim=True) + cost_dt = (cost * dt).reshape(-1, pos_coe.shape[1]) # [B*H*V, N] + cost_colli = cost_dt.sum(dim=-1) - return cost_colli.squeeze() + return cost_colli def get_distance_cost(self, pos, map_id): """