feat: better image logging

Former-commit-id: 5ce04a5534cef72a3815be13cd79731800b7419f [formerly 47ab70b94c3f4a696b4e5a131e087660dba0b8ba]
Former-commit-id: 81858cf5970cdecdbd0a151e3547a25857e3e958
This commit is contained in:
Laurent Fainsin 2022-09-12 09:28:11 +02:00
parent 0693f02d83
commit b701afe363

View file

@ -7,16 +7,54 @@ columns = [
] ]
class_labels = { class_labels = {
1: "sphere", 1: "sphere",
2: "sphere_gt",
} }
class TableLog(Callback): class TableLog(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if batch_idx == 0:
rows = []
# unpacking
images, targets = batch
for i, (image, target) in enumerate(
zip(
images,
targets,
)
):
rows.append(
[
i,
wandb.Image(
image.cpu(),
masks={
"ground_truth": {
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy() * 2,
"class_labels": class_labels,
},
},
),
]
)
wandb.log(
{
"train/predictions": wandb.Table(
columns=columns,
data=rows,
)
}
)
def on_validation_epoch_start(self, trainer, pl_module): def on_validation_epoch_start(self, trainer, pl_module):
self.rows = [] self.rows = []
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
# unpacking
if batch_idx == 2: if batch_idx == 2:
# unpacking
images, targets = batch images, targets = batch
for i, (image, target, pred) in enumerate( for i, (image, target, pred) in enumerate(
@ -26,6 +64,37 @@ class TableLog(Callback):
outputs, outputs,
) )
): ):
box_data_gt = [
{
"position": {
"minX": int(target["boxes"][j][0]),
"minY": int(target["boxes"][j][1]),
"maxX": int(target["boxes"][j][2]),
"maxY": int(target["boxes"][j][3]),
},
"domain": "pixel",
"class_id": 2,
"class_labels": class_labels,
}
for j in range(len(target["labels"]))
]
box_data = [
{
"position": {
"minX": int(pred["boxes"][j][0]),
"minY": int(pred["boxes"][j][1]),
"maxX": int(pred["boxes"][j][2]),
"maxY": int(pred["boxes"][j][3]),
},
"domain": "pixel",
"class_id": 1,
"box_caption": f"{pred['scores'][j]:0.3f}",
"class_labels": class_labels,
}
for j in range(len(pred["labels"]))
]
self.rows.append( self.rows.append(
[ [
i, i,
@ -33,14 +102,18 @@ class TableLog(Callback):
image.cpu(), image.cpu(),
masks={ masks={
"ground_truth": { "ground_truth": {
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(), "mask_data": target["masks"].cpu().sum(dim=0).int().numpy() * 2,
"class_labels": class_labels, "class_labels": class_labels,
}, },
"predictions": { "predictions": {
"mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(), "mask_data": pred["masks"].cpu().sum(dim=0).int().numpy(),
"class_labels": class_labels, "class_labels": class_labels,
}, },
}, },
boxes={
"ground_truth": {"box_data": box_data_gt},
"predictions": {"box_data": box_data},
},
), ),
] ]
) )