feat: compute mAP for both "bbox" and "segm"
Former-commit-id: 85379c46352b725ceaa7955d7d44b350ef02708a [formerly 2a503dbc4852efc4d21284ded746451ba9aaa495] Former-commit-id: b8a9c3e26b4085b0724b17e79bfa30b7727fb310
This commit is contained in:
parent
5fe7ceb306
commit
f50b758102
|
@ -67,7 +67,8 @@ class MRCNNModule(pl.LightningModule):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def on_validation_epoch_start(self):
|
def on_validation_epoch_start(self):
|
||||||
self.metric = MeanAveragePrecision(iou_type="bbox")
|
self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
|
||||||
|
self.metric_segm = MeanAveragePrecision(iou_type="segm")
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
# unpack batch
|
# unpack batch
|
||||||
|
@ -75,16 +76,21 @@ class MRCNNModule(pl.LightningModule):
|
||||||
|
|
||||||
preds = self.model(images)
|
preds = self.model(images)
|
||||||
for pred, target in zip(preds, targets):
|
for pred, target in zip(preds, targets):
|
||||||
pred["masks"] = pred["masks"].squeeze(1).bool()
|
pred["masks"] = pred["masks"].squeeze(1).int().bool()
|
||||||
target["masks"] = target["masks"].squeeze(1).bool()
|
target["masks"] = target["masks"].squeeze(1).int().bool()
|
||||||
self.metric.update(preds, targets)
|
self.metric_bbox.update(preds, targets)
|
||||||
|
self.metric_segm.update(preds, targets)
|
||||||
|
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs):
|
||||||
# log metrics
|
# log metrics
|
||||||
metric_dict = self.metric.compute()
|
metric_dict = self.metric_bbox.compute()
|
||||||
metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
|
metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
|
||||||
|
self.log_dict(metric_dict)
|
||||||
|
|
||||||
|
metric_dict = self.metric_segm.compute()
|
||||||
|
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
|
||||||
self.log_dict(metric_dict)
|
self.log_dict(metric_dict)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
|
|
Loading…
Reference in a new issue