From f50b758102badccb5d564cdddc2c5b8c13a3f21a Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Mon, 12 Sep 2022 09:25:40 +0200 Subject: [PATCH] feat: compute mAP for both "bbox" and "segm" Former-commit-id: 85379c46352b725ceaa7955d7d44b350ef02708a [formerly 2a503dbc4852efc4d21284ded746451ba9aaa495] Former-commit-id: b8a9c3e26b4085b0724b17e79bfa30b7727fb310 --- src/mrcnn/module.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/mrcnn/module.py b/src/mrcnn/module.py index 03da906..21a7720 100644 --- a/src/mrcnn/module.py +++ b/src/mrcnn/module.py @@ -67,7 +67,8 @@ class MRCNNModule(pl.LightningModule): return loss def on_validation_epoch_start(self): - self.metric = MeanAveragePrecision(iou_type="bbox") + self.metric_bbox = MeanAveragePrecision(iou_type="bbox") + self.metric_segm = MeanAveragePrecision(iou_type="segm") def validation_step(self, batch, batch_idx): # unpack batch @@ -75,16 +76,21 @@ class MRCNNModule(pl.LightningModule): preds = self.model(images) for pred, target in zip(preds, targets): - pred["masks"] = pred["masks"].squeeze(1).bool() - target["masks"] = target["masks"].squeeze(1).bool() - self.metric.update(preds, targets) + pred["masks"] = pred["masks"].squeeze(1).int().bool() + target["masks"] = target["masks"].squeeze(1).int().bool() + self.metric_bbox.update(preds, targets) + self.metric_segm.update(preds, targets) return preds def validation_epoch_end(self, outputs): # log metrics - metric_dict = self.metric.compute() - metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()} + metric_dict = self.metric_bbox.compute() + metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()} + self.log_dict(metric_dict) + + metric_dict = self.metric_segm.compute() + metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()} self.log_dict(metric_dict) def configure_optimizers(self):