mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
fix: TableLog
Former-commit-id: 9ec175df257d4f969c3b024735a021197f070432 [formerly 96dfc4589152e497cfdc5d0f9291957307d2cda4] Former-commit-id: 2f7899f654158e5be4b4ae59db0cc524d1a8ead4
This commit is contained in:
parent
c312513eff
commit
56e24615e3
|
@ -4,10 +4,6 @@ from pytorch_lightning.callbacks import Callback
|
|||
columns = [
|
||||
"ID",
|
||||
"image",
|
||||
"ground truth",
|
||||
"prediction",
|
||||
"dice",
|
||||
"dice_bin",
|
||||
]
|
||||
class_labels = {
|
||||
1: "sphere",
|
||||
|
@ -20,34 +16,32 @@ class TableLog(Callback):
|
|||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
# unpacking
|
||||
if batch_idx == 0:
|
||||
images, ground_truth = batch
|
||||
metrics, predictions = outputs
|
||||
if batch_idx == 2:
|
||||
images, targets = batch
|
||||
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
for i, (image, target, pred) in enumerate(
|
||||
zip(
|
||||
images.cpu(),
|
||||
ground_truth.cpu(),
|
||||
predictions["linear"].cpu().float(),
|
||||
predictions["binary"].cpu().squeeze(1).int().numpy(),
|
||||
images,
|
||||
targets,
|
||||
outputs,
|
||||
)
|
||||
):
|
||||
self.rows.append(
|
||||
[
|
||||
i,
|
||||
wandb.Image(img),
|
||||
wandb.Image(mask),
|
||||
wandb.Image(
|
||||
pred,
|
||||
image.cpu(),
|
||||
masks={
|
||||
"ground_truth": {
|
||||
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
"predictions": {
|
||||
"mask_data": pred_bin,
|
||||
"mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
),
|
||||
metrics["dice"],
|
||||
metrics["dice_bin"],
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -55,7 +49,7 @@ class TableLog(Callback):
|
|||
# log table
|
||||
wandb.log(
|
||||
{
|
||||
"val/predictions": wandb.Table(
|
||||
"valid/predictions": wandb.Table(
|
||||
columns=columns,
|
||||
data=self.rows,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue