From beafa768f71c8f057c7a7fc3e76683c3ab78c059 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Fri, 1 Jul 2022 11:19:15 +0200 Subject: [PATCH] feat: display binary mask on top of prediction Former-commit-id: a81b30919c6ef0998617822fd57e54209dcb2cfa [formerly 9f9cf75289c5f731dc6c93f86dff237aed88722c] Former-commit-id: d7272aefcf16e2e22081811c0452d7cb7e0980f3 --- src/train.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/train.py b/src/train.py index 4a434ed..a48164e 100644 --- a/src/train.py +++ b/src/train.py @@ -13,6 +13,10 @@ from unet import UNet from utils.dice import dice_coeff from utils.paste import RandomPaste +class_labels = { + 1: "sphere", +} + def main(): # setup logging @@ -215,14 +219,28 @@ def main(): # save the last validation batch to table table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) - for i, (img, mask, pred) in enumerate( + for i, (img, mask, pred, pred_bin) in enumerate( zip( images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"), + masks_pred_bin.to("cpu").squeeze().int().numpy(), ) ): - table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) + table.add_data( + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + ) # log validation metrics wandb.log(