From 6f7060838c7ed52c308b4e2b7cd66cbb2dcd59b6 Mon Sep 17 00:00:00 2001 From: TJU_Lu Date: Wed, 16 Jul 2025 21:41:27 +0800 Subject: [PATCH] Speed up dataloader (train speed x110%) --- YOPO/policy/yopo_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/YOPO/policy/yopo_trainer.py b/YOPO/policy/yopo_trainer.py index 5c766a3..96d837e 100644 --- a/YOPO/policy/yopo_trainer.py +++ b/YOPO/policy/yopo_trainer.py @@ -58,8 +58,10 @@ class YopoTrainer: print("Network Loaded! Loading Dataset...") # dataset - self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, num_workers=0) - self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, num_workers=0) + self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, + num_workers=1, pin_memory=True) + self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, + num_workers=1, pin_memory=True) print("Dataset Loaded!") def train(self, epoch, save_interval=None):