2021-08-16 00:53:00 +00:00
|
|
|
import torch
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
2022-06-28 14:36:50 +00:00
|
|
|
import wandb
|
|
|
|
from src.utils.dice import dice_coeff
|
2021-08-16 00:53:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
def evaluate(net, dataloader, device):
    """Evaluate a segmentation network on a validation set.

    The network is switched to eval mode and every batch is run through it
    under ``torch.inference_mode()``. Predictions are binarised by thresholding
    the sigmoid output at 0.5. The Dice score returned by ``dice_coeff`` for
    each batch is summed and divided by the number of batches. The images,
    ground-truth masks and predictions of the last batch are logged to wandb
    as a table with ``commit=False``, so they do not advance the wandb step on
    their own. The network is put back into training mode before the function
    returns, including when an exception is raised.

    Args:
        net: model called as ``net(images)``. Presumably it produces one logit
            channel per pixel (binary segmentation) — TODO confirm.
        dataloader: yields ``(images, masks_true)`` batches. It must support
            ``len()`` and expose ``.dataset``, which sizes the progress bar.
            ``masks_true`` is assumed to have no channel axis, e.g.
            ``(B, H, W)`` — TODO confirm; one is added with ``unsqueeze(1)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        The sum of per-batch Dice scores divided by the number of batches.
        Returns ``0`` when the dataloader yields no batches.
    """
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0
    # Tensors of the most recently processed batch, used for the wandb table.
    last_batch = None

    try:
        # The bar counts images rather than batches, so it is advanced
        # manually instead of wrapping the dataloader.
        with tqdm(
            total=len(dataloader.dataset), desc="val", unit="img", leave=False
        ) as pbar:
            for images, masks_true in dataloader:
                # move images and labels to the correct device
                images = images.to(device=device)
                # add a channel axis so the target lines up with the prediction
                masks_true = masks_true.unsqueeze(1).float().to(device=device)

                # forward pass, then binarise the predicted mask
                with torch.inference_mode():
                    masks_pred = net(images)
                    masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()

                # compute the Dice score for this batch
                dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)

                last_batch = (images, masks_true, masks_pred)
                pbar.update(images.shape[0])

        # Log a sample table only when at least one batch was seen. Otherwise
        # there is nothing to show and the batch tensors would be undefined.
        if last_batch is not None:
            images, masks_true, masks_pred = (t.to("cpu") for t in last_batch)
            table = wandb.Table(columns=["image", "mask", "prediction"])
            for img, mask, pred in zip(images, masks_true, masks_pred):
                table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
            wandb.log({"predictions_table": table}, commit=False)
    finally:
        # Always hand the model back in training mode, even if evaluation fails.
        net.train()

    # Guard against division by zero for an empty dataloader.
    return dice_score / num_val_batches if num_val_batches else dice_score
|