mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
fix: TableLog
Former-commit-id: 9ec175df257d4f969c3b024735a021197f070432 [formerly 96dfc4589152e497cfdc5d0f9291957307d2cda4] Former-commit-id: 2f7899f654158e5be4b4ae59db0cc524d1a8ead4
This commit is contained in:
parent
c312513eff
commit
56e24615e3
|
@ -4,10 +4,6 @@ from pytorch_lightning.callbacks import Callback
|
||||||
columns = [
|
columns = [
|
||||||
"ID",
|
"ID",
|
||||||
"image",
|
"image",
|
||||||
"ground truth",
|
|
||||||
"prediction",
|
|
||||||
"dice",
|
|
||||||
"dice_bin",
|
|
||||||
]
|
]
|
||||||
class_labels = {
|
class_labels = {
|
||||||
1: "sphere",
|
1: "sphere",
|
||||||
|
@ -20,34 +16,32 @@ class TableLog(Callback):
|
||||||
|
|
||||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||||
# unpacking
|
# unpacking
|
||||||
if batch_idx == 0:
|
if batch_idx == 2:
|
||||||
images, ground_truth = batch
|
images, targets = batch
|
||||||
metrics, predictions = outputs
|
|
||||||
|
|
||||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
for i, (image, target, pred) in enumerate(
|
||||||
zip(
|
zip(
|
||||||
images.cpu(),
|
images,
|
||||||
ground_truth.cpu(),
|
targets,
|
||||||
predictions["linear"].cpu().float(),
|
outputs,
|
||||||
predictions["binary"].cpu().squeeze(1).int().numpy(),
|
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
self.rows.append(
|
self.rows.append(
|
||||||
[
|
[
|
||||||
i,
|
i,
|
||||||
wandb.Image(img),
|
|
||||||
wandb.Image(mask),
|
|
||||||
wandb.Image(
|
wandb.Image(
|
||||||
pred,
|
image.cpu(),
|
||||||
masks={
|
masks={
|
||||||
|
"ground_truth": {
|
||||||
|
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
"predictions": {
|
"predictions": {
|
||||||
"mask_data": pred_bin,
|
"mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
||||||
"class_labels": class_labels,
|
"class_labels": class_labels,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
metrics["dice"],
|
|
||||||
metrics["dice_bin"],
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,7 +49,7 @@ class TableLog(Callback):
|
||||||
# log table
|
# log table
|
||||||
wandb.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
"val/predictions": wandb.Table(
|
"valid/predictions": wandb.Table(
|
||||||
columns=columns,
|
columns=columns,
|
||||||
data=self.rows,
|
data=self.rows,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue