Speed up dataloader (train speed x110%)

This commit is contained in:
TJU_Lu 2025-07-16 21:41:27 +08:00
parent b7bfea3a5c
commit 6f7060838c

View File

@ -58,8 +58,10 @@ class YopoTrainer:
print("Network Loaded! Loading Dataset...") print("Network Loaded! Loading Dataset...")
# dataset # dataset
self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, num_workers=0) self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True,
self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, num_workers=0) num_workers=1, pin_memory=True)
self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False,
num_workers=1, pin_memory=True)
print("Dataset Loaded!") print("Dataset Loaded!")
def train(self, epoch, save_interval=None): def train(self, epoch, save_interval=None):