feat: basic maskrcnn loss logging

Former-commit-id: 2180d68e979deffb743fecb6b34f5d1a20c7d729 [formerly bf1933b5b68e107d94c3f0834fb7321407c4bf1e]
Former-commit-id: 64219fc194f5f2ea41a4922afbc1e7d6dd9ee06a
This commit is contained in:
Laurent Fainsin 2022-08-26 11:09:25 +02:00
parent 4696885a30
commit 4ce22005cf

View file

@ -3,17 +3,15 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchvision import torchvision
from torchvision.models.detection._utils import Matcher
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.ops.boxes import box_iou
import wandb import wandb
def get_model_instance_segmentation(num_classes): def get_model_instance_segmentation(num_classes):
# load an instance segmentation model pre-trained on COCO # load an instance segmentation model pre-trained on COCO
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) # TODO: tester v2
# get number of input features for the classifier # get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features in_features = model.roi_heads.box_predictor.cls_score.in_features
@ -59,8 +57,11 @@ class MRCNNModule(pl.LightningModule):
# fasterrcnn takes both images and targets for training # fasterrcnn takes both images and targets for training
loss_dict = self.model(images, targets) loss_dict = self.model(images, targets)
loss = sum(loss_dict.values()) loss = sum(loss_dict.values())
# self.log_dict(loss_dict)
# self.log(loss) # log everything
self.log_dict(loss_dict)
self.log("train/loss", loss)
return {"loss": loss, "log": loss_dict} return {"loss": loss, "log": loss_dict}
# def validation_step(self, batch, batch_idx): # def validation_step(self, batch, batch_idx):