From 81938b944e7ec515733b68d2bdd22ce8823755ac Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 30 Jun 2022 10:48:05 +0200 Subject: [PATCH] feat: create overlay instead of table Former-commit-id: e6a0ea3ce845ab6ff524f98cebd49151da5765de [formerly 1fc368fdaba7b892357848fd47a8e53ed34a2c35] Former-commit-id: afc34abfffe0e309d6e4ae77d4d56d4f372b9e9d --- src/evaluate.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/evaluate.py b/src/evaluate.py index 1d0995c..3811d14 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -1,9 +1,14 @@ +import numpy as np import torch from tqdm import tqdm import wandb from src.utils.dice import dice_coeff +class_labels = { + 1: "sphere", +} + def evaluate(net, dataloader, device): net.eval() @@ -29,10 +34,28 @@ def evaluate(net, dataloader, device): pbar.update(images.shape[0]) # save some images to wandb - table = wandb.Table(columns=["id", "image", "mask", "prediction"]) - for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))): - table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) - wandb.log({"predictions_table": table}, commit=False) + overlays = [] + for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")): + mask_img = np.asarray(mask > 0.5, np.uint8).squeeze(0) # tester des trucs sans le threshold + pred_img = np.asarray(pred > 0.5, np.uint8).squeeze(0) + + overlays.append( + wandb.Image( + img, + masks={ + "ground_truth": { + "mask_data": mask_img, + "class_labels": class_labels, + }, + "predictions": { + "mask_data": pred_img, + "class_labels": class_labels, + }, + }, + ) + ) + + wandb.log({"val/images": overlays}) net.train()