diff --git a/train.py b/train.py index 55f3a82..76e9d99 100644 --- a/train.py +++ b/train.py @@ -70,11 +70,6 @@ def train_net(net, f'but loaded images have {imgs.shape[1]} channels. Please check that ' \ 'the images are loaded correctly.' - assert true_masks.shape[1] == net.n_classes, \ - f'Network has been defined with {net.n_classes} output classes, ' \ - f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \ - 'the masks are loaded correctly.' - imgs = imgs.to(device=device, dtype=torch.float32) true_masks = true_masks.to(device=device, dtype=torch.float32)