From 449965b398830f7850cddd52f0c067f904ff4556 Mon Sep 17 00:00:00 2001 From: milesial Date: Sat, 21 Aug 2021 10:26:42 +0200 Subject: [PATCH] Fix evaluation on 1 class Former-commit-id: d8984977848d924644d575f5e46c475d84ff1772 --- evaluate.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/evaluate.py b/evaluate.py index 7db882d..504432f 100644 --- a/evaluate.py +++ b/evaluate.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from tqdm import tqdm -from utils.dice_score import multiclass_dice_coeff +from utils.dice_score import multiclass_dice_coeff, dice_coeff def evaluate(net, dataloader, device): @@ -25,11 +25,14 @@ def evaluate(net, dataloader, device): # convert to one-hot format if net.n_classes == 1: mask_pred = (F.sigmoid(mask_pred) > 0.5).float() + # compute the Dice score + dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False) else: mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() + # compute the Dice score, ignoring background + dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False) - # compute the Dice score, ignoring background - dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False) + net.train() return dice_score / num_val_batches