Speed up dataloader (train speed x110%)
This commit is contained in:
parent
b7bfea3a5c
commit
6f7060838c
@ -58,8 +58,10 @@ class YopoTrainer:
|
|||||||
print("Network Loaded! Loading Dataset...")
|
print("Network Loaded! Loading Dataset...")
|
||||||
|
|
||||||
# dataset
|
# dataset
|
||||||
self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, num_workers=0)
|
self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True,
|
||||||
self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, num_workers=0)
|
num_workers=1, pin_memory=True)
|
||||||
|
self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False,
|
||||||
|
num_workers=1, pin_memory=True)
|
||||||
print("Dataset Loaded!")
|
print("Dataset Loaded!")
|
||||||
|
|
||||||
def train(self, epoch, save_interval=None):
|
def train(self, epoch, save_interval=None):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user