mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: better image logging
Former-commit-id: 5ce04a5534cef72a3815be13cd79731800b7419f [formerly 47ab70b94c3f4a696b4e5a131e087660dba0b8ba] Former-commit-id: 81858cf5970cdecdbd0a151e3547a25857e3e958
This commit is contained in:
parent
0693f02d83
commit
b701afe363
|
@ -7,16 +7,54 @@ columns = [
|
|||
]
|
||||
class_labels = {
|
||||
1: "sphere",
|
||||
2: "sphere_gt",
|
||||
}
|
||||
|
||||
|
||||
class TableLog(Callback):
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if batch_idx == 0:
|
||||
rows = []
|
||||
|
||||
# unpacking
|
||||
images, targets = batch
|
||||
|
||||
for i, (image, target) in enumerate(
|
||||
zip(
|
||||
images,
|
||||
targets,
|
||||
)
|
||||
):
|
||||
rows.append(
|
||||
[
|
||||
i,
|
||||
wandb.Image(
|
||||
image.cpu(),
|
||||
masks={
|
||||
"ground_truth": {
|
||||
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy() * 2,
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
wandb.log(
|
||||
{
|
||||
"train/predictions": wandb.Table(
|
||||
columns=columns,
|
||||
data=rows,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def on_validation_epoch_start(self, trainer, pl_module):
|
||||
self.rows = []
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
# unpacking
|
||||
if batch_idx == 2:
|
||||
# unpacking
|
||||
images, targets = batch
|
||||
|
||||
for i, (image, target, pred) in enumerate(
|
||||
|
@ -26,6 +64,37 @@ class TableLog(Callback):
|
|||
outputs,
|
||||
)
|
||||
):
|
||||
box_data_gt = [
|
||||
{
|
||||
"position": {
|
||||
"minX": int(target["boxes"][j][0]),
|
||||
"minY": int(target["boxes"][j][1]),
|
||||
"maxX": int(target["boxes"][j][2]),
|
||||
"maxY": int(target["boxes"][j][3]),
|
||||
},
|
||||
"domain": "pixel",
|
||||
"class_id": 2,
|
||||
"class_labels": class_labels,
|
||||
}
|
||||
for j in range(len(target["labels"]))
|
||||
]
|
||||
|
||||
box_data = [
|
||||
{
|
||||
"position": {
|
||||
"minX": int(pred["boxes"][j][0]),
|
||||
"minY": int(pred["boxes"][j][1]),
|
||||
"maxX": int(pred["boxes"][j][2]),
|
||||
"maxY": int(pred["boxes"][j][3]),
|
||||
},
|
||||
"domain": "pixel",
|
||||
"class_id": 1,
|
||||
"box_caption": f"{pred['scores'][j]:0.3f}",
|
||||
"class_labels": class_labels,
|
||||
}
|
||||
for j in range(len(pred["labels"]))
|
||||
]
|
||||
|
||||
self.rows.append(
|
||||
[
|
||||
i,
|
||||
|
@ -33,14 +102,18 @@ class TableLog(Callback):
|
|||
image.cpu(),
|
||||
masks={
|
||||
"ground_truth": {
|
||||
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
||||
"mask_data": target["masks"].cpu().sum(dim=0).int().numpy() * 2,
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
"predictions": {
|
||||
"mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(),
|
||||
"mask_data": pred["masks"].cpu().sum(dim=0).int().numpy(),
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
boxes={
|
||||
"ground_truth": {"box_data": box_data_gt},
|
||||
"predictions": {"box_data": box_data},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue