mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: compute mAP for both "bbox" and "segm"
Former-commit-id: 85379c46352b725ceaa7955d7d44b350ef02708a [formerly 2a503dbc4852efc4d21284ded746451ba9aaa495] Former-commit-id: b8a9c3e26b4085b0724b17e79bfa30b7727fb310
This commit is contained in:
parent
5fe7ceb306
commit
f50b758102
|
@ -67,7 +67,8 @@ class MRCNNModule(pl.LightningModule):
|
|||
return loss
|
||||
|
||||
def on_validation_epoch_start(self):
|
||||
self.metric = MeanAveragePrecision(iou_type="bbox")
|
||||
self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
|
||||
self.metric_segm = MeanAveragePrecision(iou_type="segm")
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
# unpack batch
|
||||
|
@ -75,16 +76,21 @@ class MRCNNModule(pl.LightningModule):
|
|||
|
||||
preds = self.model(images)
|
||||
for pred, target in zip(preds, targets):
|
||||
pred["masks"] = pred["masks"].squeeze(1).bool()
|
||||
target["masks"] = target["masks"].squeeze(1).bool()
|
||||
self.metric.update(preds, targets)
|
||||
pred["masks"] = pred["masks"].squeeze(1).int().bool()
|
||||
target["masks"] = target["masks"].squeeze(1).int().bool()
|
||||
self.metric_bbox.update(preds, targets)
|
||||
self.metric_segm.update(preds, targets)
|
||||
|
||||
return preds
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
# log metrics
|
||||
metric_dict = self.metric.compute()
|
||||
metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
|
||||
metric_dict = self.metric_bbox.compute()
|
||||
metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
|
||||
self.log_dict(metric_dict)
|
||||
|
||||
metric_dict = self.metric_segm.compute()
|
||||
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
|
||||
self.log_dict(metric_dict)
|
||||
|
||||
def configure_optimizers(self):
|
||||
|
|
Loading…
Reference in a new issue