From 8c761f9d06eccf881f8a50c5368884b9c080746b Mon Sep 17 00:00:00 2001 From: milesial Date: Sat, 21 Aug 2021 08:41:23 +0200 Subject: [PATCH] Fix evaluation Former-commit-id: 2b649bc9a337818696291280cd87ad93a6fc5032 --- evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/evaluate.py b/evaluate.py index 88b337b..7db882d 100644 --- a/evaluate.py +++ b/evaluate.py @@ -24,12 +24,12 @@ def evaluate(net, dataloader, device): # convert to one-hot format if net.n_classes == 1: - mask_pred = (F.sigmoid(mask_pred) > 0).float() + mask_pred = (F.sigmoid(mask_pred) > 0.5).float() else: mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() # compute the Dice score, ignoring background - dice_score += multiclass_dice_coeff(mask_pred[:, :1, ...], mask_true[:, :1, ...], reduce_batch_first=False) + dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False) net.train() return dice_score / num_val_batches