fix: TableLog

Former-commit-id: 9ec175df257d4f969c3b024735a021197f070432 [formerly 96dfc4589152e497cfdc5d0f9291957307d2cda4]
Former-commit-id: 2f7899f654158e5be4b4ae59db0cc524d1a8ead4
This commit is contained in:
Laurent Fainsin 2022-09-07 10:42:58 +02:00
parent c312513eff
commit 56e24615e3

View file

@ -4,10 +4,6 @@ from pytorch_lightning.callbacks import Callback
columns = [ columns = [
"ID", "ID",
"image", "image",
"ground truth",
"prediction",
"dice",
"dice_bin",
] ]
class_labels = { class_labels = {
1: "sphere", 1: "sphere",
@ -20,34 +16,32 @@ class TableLog(Callback):
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
# unpacking # unpacking
if batch_idx == 0: if batch_idx == 2:
images, ground_truth = batch images, targets = batch
metrics, predictions = outputs
for i, (img, mask, pred, pred_bin) in enumerate( for i, (image, target, pred) in enumerate(
zip( zip(
images.cpu(), images,
ground_truth.cpu(), targets,
predictions["linear"].cpu().float(), outputs,
predictions["binary"].cpu().squeeze(1).int().numpy(),
) )
): ):
self.rows.append( self.rows.append(
[ [
i, i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image( wandb.Image(
pred, image.cpu(),
masks={ masks={
"ground_truth": {
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
"class_labels": class_labels,
},
"predictions": { "predictions": {
"mask_data": pred_bin, "mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
"class_labels": class_labels, "class_labels": class_labels,
}, },
}, },
), ),
metrics["dice"],
metrics["dice_bin"],
] ]
) )
@ -55,7 +49,7 @@ class TableLog(Callback):
# log table # log table
wandb.log( wandb.log(
{ {
"val/predictions": wandb.Table( "valid/predictions": wandb.Table(
columns=columns, columns=columns,
data=self.rows, data=self.rows,
) )