update pretrained model

This commit is contained in:
TJU-Lu 2025-12-05 23:09:22 +08:00
parent 39cece7fac
commit b92ca64731
6 changed files with 4 additions and 5 deletions

View File

@ -101,7 +101,7 @@ You can refer to [config.yaml](Simulator/src/config/config.yaml) for modificatio
**3. Start the YOPO Planner**
You can refer to [traj_opt.yaml](YOPO/config/traj_opt.yaml) for modification of the flight speed (The given weights are pretrained at 6 m/s and perform smoothly at speeds between 0 - 6 m/s).
You can refer to [traj_opt.yaml](YOPO/config/traj_opt.yaml) for modification of the flight speed (The given weights are pretrained at 6 m/s and perform smoothly at speeds between 0 - 6 m/s, and more pretrained models are available at [Releases](https://github.com/TJU-Aerial-Robotics/YOPO/releases)).
```
cd YOPO

View File

@ -66,7 +66,7 @@ class GuidanceLoss(nn.Module):
goal_length = goal_dir.norm(dim=1) # [B]
# length difference along goal direction (cosine similarity)
parallel_diff = (goal_length - traj_along).abs() # [B]
parallel_diff = F.smooth_l1_loss(goal_length, traj_along, reduction='none')
# length perpendicular to goal direction
traj_perp = traj_dir - traj_along.unsqueeze(1) * goal_dir_norm # [B, 3]

BIN
YOPO/saved/YOPO_1/epoch50.pth Executable file → Normal file

Binary file not shown.

View File

@ -140,7 +140,7 @@ class YopoNet:
obs = np.concatenate((vel_c, acc_c, goal_c), axis=0).astype(np.float32)
obs_norm = self.state_transform.normalize_obs(torch.from_numpy(obs[None, :]))
return obs_norm.to(self.device, non_blocking=True)
return obs_norm
@torch.inference_mode()
def callback_depth(self, data):
@ -171,9 +171,8 @@ class YopoNet:
# input prepare
time1 = time.time()
depth_input = torch.from_numpy(depth).to(self.device, non_blocking=True) # (non_blocking: copying speed 3x)
obs_norm = self.process_odom()
obs_norm = self.process_odom().to(self.device, non_blocking=True)
obs_input = self.state_transform.prepare_input(obs_norm)
obs_input = obs_input.to(self.device, non_blocking=True)
# torch.cuda.synchronize()
time2 = time.time()