diff --git a/src/utils/callback.py b/src/utils/callback.py index 37366e3..b00c956 100644 --- a/src/utils/callback.py +++ b/src/utils/callback.py @@ -7,16 +7,54 @@ columns = [ ] class_labels = { 1: "sphere", + 2: "sphere_gt", } class TableLog(Callback): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if batch_idx == 0: + rows = [] + + # unpacking + images, targets = batch + + for i, (image, target) in enumerate( + zip( + images, + targets, + ) + ): + rows.append( + [ + i, + wandb.Image( + image.cpu(), + masks={ + "ground_truth": { + "mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy() * 2, + "class_labels": class_labels, + }, + }, + ), + ] + ) + + wandb.log( + { + "train/predictions": wandb.Table( + columns=columns, + data=rows, + ) + } + ) + def on_validation_epoch_start(self, trainer, pl_module): self.rows = [] def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - # unpacking if batch_idx == 2: + # unpacking images, targets = batch for i, (image, target, pred) in enumerate( @@ -26,6 +64,37 @@ class TableLog(Callback): outputs, ) ): + box_data_gt = [ + { + "position": { + "minX": int(target["boxes"][j][0]), + "minY": int(target["boxes"][j][1]), + "maxX": int(target["boxes"][j][2]), + "maxY": int(target["boxes"][j][3]), + }, + "domain": "pixel", + "class_id": 2, + "class_labels": class_labels, + } + for j in range(len(target["labels"])) + ] + + box_data = [ + { + "position": { + "minX": int(pred["boxes"][j][0]), + "minY": int(pred["boxes"][j][1]), + "maxX": int(pred["boxes"][j][2]), + "maxY": int(pred["boxes"][j][3]), + }, + "domain": "pixel", + "class_id": 1, + "box_caption": f"{pred['scores'][j]:0.3f}", + "class_labels": class_labels, + } + for j in range(len(pred["labels"])) + ] + self.rows.append( [ i, @@ -33,14 +102,18 @@ class TableLog(Callback): image.cpu(), masks={ "ground_truth": { - "mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy(), + "mask_data": target["masks"].cpu().sum(dim=0).int().numpy() * 2, "class_labels": class_labels, }, "predictions": { - "mask_data": (pred["masks"].cpu().sum(dim=0) > 0.5).int().numpy(), + "mask_data": pred["masks"].cpu().sum(dim=0).int().numpy(), "class_labels": class_labels, }, }, + boxes={ + "ground_truth": {"box_data": box_data_gt}, + "predictions": {"box_data": box_data}, + }, ), ] )