feat: create overlay instead of table
Former-commit-id: e6a0ea3ce845ab6ff524f98cebd49151da5765de [formerly 1fc368fdaba7b892357848fd47a8e53ed34a2c35] Former-commit-id: afc34abfffe0e309d6e4ae77d4d56d4f372b9e9d
This commit is contained in:
parent
dac6237906
commit
81938b944e
|
@ -1,9 +1,14 @@
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from src.utils.dice import dice_coeff
|
from src.utils.dice import dice_coeff
|
||||||
|
|
||||||
|
class_labels = {
|
||||||
|
1: "sphere",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def evaluate(net, dataloader, device):
|
def evaluate(net, dataloader, device):
|
||||||
net.eval()
|
net.eval()
|
||||||
|
@ -29,10 +34,28 @@ def evaluate(net, dataloader, device):
|
||||||
pbar.update(images.shape[0])
|
pbar.update(images.shape[0])
|
||||||
|
|
||||||
# save some images to wandb
|
# save some images to wandb
|
||||||
table = wandb.Table(columns=["id", "image", "mask", "prediction"])
|
overlays = []
|
||||||
for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
|
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
|
||||||
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
mask_img = np.asarray(mask > 0.5, np.uint8).squeeze(0) # tester des trucs sans le threshold
|
||||||
wandb.log({"predictions_table": table}, commit=False)
|
pred_img = np.asarray(pred > 0.5, np.uint8).squeeze(0)
|
||||||
|
|
||||||
|
overlays.append(
|
||||||
|
wandb.Image(
|
||||||
|
img,
|
||||||
|
masks={
|
||||||
|
"ground_truth": {
|
||||||
|
"mask_data": mask_img,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
"predictions": {
|
||||||
|
"mask_data": pred_img,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
wandb.log({"val/images": overlays})
|
||||||
|
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue