fix: TableLog

Former-commit-id: 9ec175df257d4f969c3b024735a021197f070432 [formerly 96dfc4589152e497cfdc5d0f9291957307d2cda4]
Former-commit-id: 2f7899f654158e5be4b4ae59db0cc524d1a8ead4
This commit is contained in:
Laurent Fainsin 2022-09-07 10:42:58 +02:00
parent c312513eff
commit 56e24615e3

View file

@ -4,10 +4,6 @@ from pytorch_lightning.callbacks import Callback
columns = [
"ID",
"image",
"ground truth",
"prediction",
"dice",
"dice_bin",
]
class_labels = {
1: "sphere",
@ -20,34 +16,32 @@ class TableLog(Callback):
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
# unpacking
if batch_idx == 0:
images, ground_truth = batch
metrics, predictions = outputs
if batch_idx == 2:
images, targets = batch
for i, (img, mask, pred, pred_bin) in enumerate(
for i, (image, target, pred) in enumerate(
zip(
images.cpu(),
ground_truth.cpu(),
predictions["linear"].cpu().float(),
predictions["binary"].cpu().squeeze(1).int().numpy(),
images,
targets,
outputs,
)
):
self.rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
image.cpu(),
masks={
"ground_truth": {
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
"class_labels": class_labels,
},
"predictions": {
"mask_data": pred_bin,
"mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
"class_labels": class_labels,
},
},
),
metrics["dice"],
metrics["dice_bin"],
]
)
@ -55,7 +49,7 @@ class TableLog(Callback):
# log table
wandb.log(
{
"val/predictions": wandb.Table(
"valid/predictions": wandb.Table(
columns=columns,
data=self.rows,
)