diff --git a/YOPO/policy/yopo_trainer.py b/YOPO/policy/yopo_trainer.py index 5c766a3..96d837e 100644 --- a/YOPO/policy/yopo_trainer.py +++ b/YOPO/policy/yopo_trainer.py @@ -58,8 +58,10 @@ class YopoTrainer: print("Network Loaded! Loading Dataset...") # dataset - self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, num_workers=0) - self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, num_workers=0) + self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, + num_workers=1, pin_memory=True) + self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, + num_workers=1, pin_memory=True) print("Dataset Loaded!") def train(self, epoch, save_interval=None):