2021-08-16 00:53:00 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
2021-08-21 08:26:42 +00:00
|
|
|
from utils.dice_score import multiclass_dice_coeff, dice_coeff
|
2021-08-16 00:53:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
def evaluate(net, dataloader, device):
    """Compute the batch-averaged Dice score of ``net`` over ``dataloader``.

    Each batch is a mapping with an ``"image"`` tensor and a ``"mask"`` tensor
    of integer class indices laid out as (N, H, W). When ``net.n_classes == 1``
    the sigmoid-thresholded prediction is scored with ``dice_coeff``. Otherwise
    the argmax prediction and the true mask are one-hot encoded and scored
    with ``multiclass_dice_coeff``, excluding channel 0 (background).

    The network is put in eval mode for the duration of the call. Its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        net: Segmentation module exposing an ``n_classes`` attribute.
        dataloader: Sized iterable of batches (``len(dataloader)`` is used).
        device: Device the images and masks are moved to.

    Returns:
        The sum of per-batch Dice scores divided by ``len(dataloader)``.
        If the dataloader is empty, the unscaled accumulator (``0``) is
        returned instead.
    """
    was_training = net.training
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    try:
        # iterate over the validation set
        for batch in tqdm(dataloader, total=num_val_batches, desc="Validation round", unit="batch", leave=False):
            image, mask_true = batch["image"], batch["mask"]

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            with torch.no_grad():
                # predict the mask
                mask_pred = net(image)

                if net.n_classes == 1:
                    # One-hot encoding with a single class rejects any
                    # foreground (value 1) pixel. Instead, add the channel
                    # axis so the target is (N, 1, H, W), the same layout the
                    # single-class one-hot would have produced.
                    mask_true = mask_true.unsqueeze(1).float()
                    mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
                    # compute the Dice score
                    dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
                else:
                    # convert both masks to one-hot (N, C, H, W) format
                    mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                    mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                    # compute the Dice score, ignoring background
                    dice_score += multiclass_dice_coeff(
                        mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False
                    )
    finally:
        # Restore the caller's mode rather than unconditionally switching to
        # train mode, and do so even when an exception escapes the loop.
        net.train(was_training)

    # Fixes a potential division by zero error
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches
|