Fix evaluation on 1 class

Former-commit-id: d8984977848d924644d575f5e46c475d84ff1772
This commit is contained in:
milesial 2021-08-21 10:26:42 +02:00 committed by GitHub
parent 8c761f9d06
commit 449965b398

View file

@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tqdm import tqdm from tqdm import tqdm
from utils.dice_score import multiclass_dice_coeff from utils.dice_score import multiclass_dice_coeff, dice_coeff
def evaluate(net, dataloader, device): def evaluate(net, dataloader, device):
@ -25,11 +25,14 @@ def evaluate(net, dataloader, device):
# convert to one-hot format # convert to one-hot format
if net.n_classes == 1: if net.n_classes == 1:
mask_pred = (F.sigmoid(mask_pred) > 0.5).float() mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
# compute the Dice score
dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
else: else:
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
# compute the Dice score, ignoring background
dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)
# compute the Dice score, ignoring background
dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)
net.train() net.train()
return dice_score / num_val_batches return dice_score / num_val_batches