diff --git a/eval.py b/eval.py index 003d782..761a2d3 100644 --- a/eval.py +++ b/eval.py @@ -16,7 +16,8 @@ def eval_net(net, loader, device, n_val): true_masks = batch['mask'] imgs = imgs.to(device=device, dtype=torch.float32) - true_masks = true_masks.to(device=device, dtype=torch.float32) + mask_type = torch.float32 if net.n_classes == 1 else torch.long + true_masks = true_masks.to(device=device, dtype=mask_type) mask_pred = net(imgs) diff --git a/train.py b/train.py index 76e9d99..c23aff2 100644 --- a/train.py +++ b/train.py @@ -71,7 +71,8 @@ def train_net(net, 'the images are loaded correctly.' imgs = imgs.to(device=device, dtype=torch.float32) - true_masks = true_masks.to(device=device, dtype=torch.float32) + mask_type = torch.float32 if net.n_classes == 1 else torch.long + true_masks = true_masks.to(device=device, dtype=mask_type) masks_pred = net(imgs) loss = criterion(masks_pred, true_masks)