feat: display binary mask on top of prediction

Former-commit-id: a81b30919c6ef0998617822fd57e54209dcb2cfa [formerly 9f9cf75289c5f731dc6c93f86dff237aed88722c]
Former-commit-id: d7272aefcf16e2e22081811c0452d7cb7e0980f3
This commit is contained in:
Laurent Fainsin 2022-07-01 11:19:15 +02:00
parent 31ceb97996
commit beafa768f7

View file

@ -13,6 +13,10 @@ from unet import UNet
from utils.dice import dice_coeff from utils.dice import dice_coeff
from utils.paste import RandomPaste from utils.paste import RandomPaste
class_labels = {
1: "sphere",
}
def main(): def main():
# setup logging # setup logging
@ -215,14 +219,28 @@ def main():
# save the last validation batch to table # save the last validation batch to table
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred) in enumerate( for i, (img, mask, pred, pred_bin) in enumerate(
zip( zip(
images.to("cpu"), images.to("cpu"),
masks_true.to("cpu"), masks_true.to("cpu"),
masks_pred.to("cpu"), masks_pred.to("cpu"),
masks_pred_bin.to("cpu").squeeze().int().numpy(),
) )
): ):
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) table.add_data(
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
)
# log validation metrics # log validation metrics
wandb.log( wandb.log(