diff --git a/src/mrcnn/module.py b/src/mrcnn/module.py index 03da906..21a7720 100644 --- a/src/mrcnn/module.py +++ b/src/mrcnn/module.py @@ -67,7 +67,8 @@ class MRCNNModule(pl.LightningModule): return loss def on_validation_epoch_start(self): - self.metric = MeanAveragePrecision(iou_type="bbox") + self.metric_bbox = MeanAveragePrecision(iou_type="bbox") + self.metric_segm = MeanAveragePrecision(iou_type="segm") def validation_step(self, batch, batch_idx): # unpack batch @@ -75,16 +76,21 @@ class MRCNNModule(pl.LightningModule): preds = self.model(images) for pred, target in zip(preds, targets): - pred["masks"] = pred["masks"].squeeze(1).bool() - target["masks"] = target["masks"].squeeze(1).bool() - self.metric.update(preds, targets) + pred["masks"] = pred["masks"].squeeze(1).int().bool() + target["masks"] = target["masks"].squeeze(1).int().bool() + self.metric_bbox.update(preds, targets) + self.metric_segm.update(preds, targets) return preds def validation_epoch_end(self, outputs): # log metrics - metric_dict = self.metric.compute() - metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()} + metric_dict = self.metric_bbox.compute() + metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()} + self.log_dict(metric_dict) + + metric_dict = self.metric_segm.compute() + metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()} self.log_dict(metric_dict) def configure_optimizers(self):