Update mask type for muticlass

Former-commit-id: 4dcb7b8440c5f36ff2565c67f56f8f029b589c80
This commit is contained in:
milesial 2019-12-13 17:36:12 +01:00
parent 4e1f0398a1
commit 5f4ce7dba9
2 changed files with 4 additions and 2 deletions

View file

@ -16,7 +16,8 @@ def eval_net(net, loader, device, n_val):
true_masks = batch['mask']
imgs = imgs.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
mask_pred = net(imgs)

View file

@ -71,7 +71,8 @@ def train_net(net,
'the images are loaded correctly.'
imgs = imgs.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)