diff --git a/src/train.py b/src/train.py index 4a434ed..a48164e 100644 --- a/src/train.py +++ b/src/train.py @@ -13,6 +13,10 @@ from unet import UNet from utils.dice import dice_coeff from utils.paste import RandomPaste +class_labels = { + 1: "sphere", +} + def main(): # setup logging @@ -215,14 +219,28 @@ def main(): # save the last validation batch to table table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) - for i, (img, mask, pred) in enumerate( + for i, (img, mask, pred, pred_bin) in enumerate( zip( images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"), + masks_pred_bin.to("cpu").squeeze().int().numpy(), ) ): - table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) + table.add_data( + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + ) # log validation metrics wandb.log(