feat: better image logging
Former-commit-id: 5ce04a5534cef72a3815be13cd79731800b7419f [formerly 47ab70b94c3f4a696b4e5a131e087660dba0b8ba] Former-commit-id: 81858cf5970cdecdbd0a151e3547a25857e3e958
This commit is contained in:
parent
0693f02d83
commit
b701afe363
|
@ -7,16 +7,54 @@ columns = [
|
||||||
]
|
]
|
||||||
class_labels = {
|
class_labels = {
|
||||||
1: "sphere",
|
1: "sphere",
|
||||||
|
2: "sphere_gt",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TableLog(Callback):
|
class TableLog(Callback):
|
||||||
|
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||||
|
if batch_idx == 0:
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
# unpacking
|
||||||
|
images, targets = batch
|
||||||
|
|
||||||
|
for i, (image, target) in enumerate(
|
||||||
|
zip(
|
||||||
|
images,
|
||||||
|
targets,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
wandb.Image(
|
||||||
|
image.cpu(),
|
||||||
|
masks={
|
||||||
|
"ground_truth": {
|
||||||
|
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy() * 2,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"train/predictions": wandb.Table(
|
||||||
|
columns=columns,
|
||||||
|
data=rows,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def on_validation_epoch_start(self, trainer, pl_module):
|
def on_validation_epoch_start(self, trainer, pl_module):
|
||||||
self.rows = []
|
self.rows = []
|
||||||
|
|
||||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||||
# unpacking
|
|
||||||
if batch_idx == 2:
|
if batch_idx == 2:
|
||||||
|
# unpacking
|
||||||
images, targets = batch
|
images, targets = batch
|
||||||
|
|
||||||
for i, (image, target, pred) in enumerate(
|
for i, (image, target, pred) in enumerate(
|
||||||
|
@ -26,6 +64,37 @@ class TableLog(Callback):
|
||||||
outputs,
|
outputs,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
|
box_data_gt = [
|
||||||
|
{
|
||||||
|
"position": {
|
||||||
|
"minX": int(target["boxes"][j][0]),
|
||||||
|
"minY": int(target["boxes"][j][1]),
|
||||||
|
"maxX": int(target["boxes"][j][2]),
|
||||||
|
"maxY": int(target["boxes"][j][3]),
|
||||||
|
},
|
||||||
|
"domain": "pixel",
|
||||||
|
"class_id": 2,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
}
|
||||||
|
for j in range(len(target["labels"]))
|
||||||
|
]
|
||||||
|
|
||||||
|
box_data = [
|
||||||
|
{
|
||||||
|
"position": {
|
||||||
|
"minX": int(pred["boxes"][j][0]),
|
||||||
|
"minY": int(pred["boxes"][j][1]),
|
||||||
|
"maxX": int(pred["boxes"][j][2]),
|
||||||
|
"maxY": int(pred["boxes"][j][3]),
|
||||||
|
},
|
||||||
|
"domain": "pixel",
|
||||||
|
"class_id": 1,
|
||||||
|
"box_caption": f"{pred['scores'][j]:0.3f}",
|
||||||
|
"class_labels": class_labels,
|
||||||
|
}
|
||||||
|
for j in range(len(pred["labels"]))
|
||||||
|
]
|
||||||
|
|
||||||
self.rows.append(
|
self.rows.append(
|
||||||
[
|
[
|
||||||
i,
|
i,
|
||||||
|
@ -33,14 +102,18 @@ class TableLog(Callback):
|
||||||
image.cpu(),
|
image.cpu(),
|
||||||
masks={
|
masks={
|
||||||
"ground_truth": {
|
"ground_truth": {
|
||||||
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
"mask_data": target["masks"].cpu().sum(dim=0).int().numpy() * 2,
|
||||||
"class_labels": class_labels,
|
"class_labels": class_labels,
|
||||||
},
|
},
|
||||||
"predictions": {
|
"predictions": {
|
||||||
"mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
"mask_data": pred["masks"].cpu().sum(dim=0).int().numpy(),
|
||||||
"class_labels": class_labels,
|
"class_labels": class_labels,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
boxes={
|
||||||
|
"ground_truth": {"box_data": box_data_gt},
|
||||||
|
"predictions": {"box_data": box_data},
|
||||||
|
},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue