Update mask type for muticlass
Former-commit-id: 4dcb7b8440c5f36ff2565c67f56f8f029b589c80
This commit is contained in:
parent
4e1f0398a1
commit
5f4ce7dba9
3
eval.py
3
eval.py
|
@ -16,7 +16,8 @@ def eval_net(net, loader, device, n_val):
|
|||
true_masks = batch['mask']
|
||||
|
||||
imgs = imgs.to(device=device, dtype=torch.float32)
|
||||
true_masks = true_masks.to(device=device, dtype=torch.float32)
|
||||
mask_type = torch.float32 if net.n_classes == 1 else torch.long
|
||||
true_masks = true_masks.to(device=device, dtype=mask_type)
|
||||
|
||||
mask_pred = net(imgs)
|
||||
|
||||
|
|
3
train.py
3
train.py
|
@ -71,7 +71,8 @@ def train_net(net,
|
|||
'the images are loaded correctly.'
|
||||
|
||||
imgs = imgs.to(device=device, dtype=torch.float32)
|
||||
true_masks = true_masks.to(device=device, dtype=torch.float32)
|
||||
mask_type = torch.float32 if net.n_classes == 1 else torch.long
|
||||
true_masks = true_masks.to(device=device, dtype=mask_type)
|
||||
|
||||
masks_pred = net(imgs)
|
||||
loss = criterion(masks_pred, true_masks)
|
||||
|
|
Loading…
Reference in a new issue