From 56e24615e3fe548dbe38b542fba2bf4bd107b63e Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 7 Sep 2022 10:42:58 +0200 Subject: [PATCH] fix: TableLog Former-commit-id: 9ec175df257d4f969c3b024735a021197f070432 [formerly 96dfc4589152e497cfdc5d0f9291957307d2cda4] Former-commit-id: 2f7899f654158e5be4b4ae59db0cc524d1a8ead4 --- src/utils/callback.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/src/utils/callback.py b/src/utils/callback.py index 7908a73..37366e3 100644 --- a/src/utils/callback.py +++ b/src/utils/callback.py @@ -4,10 +4,6 @@ from pytorch_lightning.callbacks import Callback columns = [ "ID", "image", - "ground truth", - "prediction", - "dice", - "dice_bin", ] class_labels = { 1: "sphere", @@ -20,34 +16,32 @@ class TableLog(Callback): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # unpacking - if batch_idx == 0: - images, ground_truth = batch - metrics, predictions = outputs + if batch_idx == 2: + images, targets = batch - for i, (img, mask, pred, pred_bin) in enumerate( + for i, (image, target, pred) in enumerate( zip( - images.cpu(), - ground_truth.cpu(), - predictions["linear"].cpu().float(), - predictions["binary"].cpu().squeeze(1).int().numpy(), + images, + targets, + outputs, ) ): self.rows.append( [ i, - wandb.Image(img), - wandb.Image(mask), wandb.Image( - pred, + image.cpu(), masks={ + "ground_truth": { + "mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(), + "class_labels": class_labels, + }, "predictions": { - "mask_data": pred_bin, + "mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(), "class_labels": class_labels, }, }, ), - metrics["dice"], - metrics["dice_bin"], ] ) @@ -55,7 +49,7 @@ class TableLog(Callback): # log table wandb.log( { - "val/predictions": wandb.Table( + "valid/predictions": wandb.Table( columns=columns, data=self.rows, )