feat: compute mAP for both "bbox" and "segm"

Former-commit-id: 85379c46352b725ceaa7955d7d44b350ef02708a [formerly 2a503dbc4852efc4d21284ded746451ba9aaa495]
Former-commit-id: b8a9c3e26b4085b0724b17e79bfa30b7727fb310
This commit is contained in:
Laurent Fainsin 2022-09-12 09:25:40 +02:00
parent 5fe7ceb306
commit f50b758102

View file

@ -67,7 +67,8 @@ class MRCNNModule(pl.LightningModule):
return loss return loss
def on_validation_epoch_start(self): def on_validation_epoch_start(self):
self.metric = MeanAveragePrecision(iou_type="bbox") self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
self.metric_segm = MeanAveragePrecision(iou_type="segm")
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
# unpack batch # unpack batch
@ -75,16 +76,21 @@ class MRCNNModule(pl.LightningModule):
preds = self.model(images) preds = self.model(images)
for pred, target in zip(preds, targets): for pred, target in zip(preds, targets):
pred["masks"] = pred["masks"].squeeze(1).bool() pred["masks"] = pred["masks"].squeeze(1).int().bool()
target["masks"] = target["masks"].squeeze(1).bool() target["masks"] = target["masks"].squeeze(1).int().bool()
self.metric.update(preds, targets) self.metric_bbox.update(preds, targets)
self.metric_segm.update(preds, targets)
return preds return preds
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs):
# log metrics # log metrics
metric_dict = self.metric.compute() metric_dict = self.metric_bbox.compute()
metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()} metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
self.log_dict(metric_dict)
metric_dict = self.metric_segm.compute()
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
self.log_dict(metric_dict) self.log_dict(metric_dict)
def configure_optimizers(self): def configure_optimizers(self):