diff --git a/train.py b/train.py index 5ac3f3b..90a60ae 100644 --- a/train.py +++ b/train.py @@ -35,7 +35,7 @@ def train_net(net, n_train = len(dataset) - n_val train, val = random_split(dataset, [n_train, n_val]) train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) - val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True) + val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True) writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}') global_step = 0