feat: basic maskrcnn loss logging
Former-commit-id: 2180d68e979deffb743fecb6b34f5d1a20c7d729 [formerly bf1933b5b68e107d94c3f0834fb7321407c4bf1e] Former-commit-id: 64219fc194f5f2ea41a4922afbc1e7d6dd9ee06a
This commit is contained in:
parent
4696885a30
commit
4ce22005cf
|
@ -3,17 +3,15 @@
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision.models.detection._utils import Matcher
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||
from torchvision.ops.boxes import box_iou
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
def get_model_instance_segmentation(num_classes):
|
||||
# load an instance segmentation model pre-trained on COCO
|
||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) # TODO: tester v2
|
||||
|
||||
# get number of input features for the classifier
|
||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||
|
@ -59,8 +57,11 @@ class MRCNNModule(pl.LightningModule):
|
|||
# fasterrcnn takes both images and targets for training
|
||||
loss_dict = self.model(images, targets)
|
||||
loss = sum(loss_dict.values())
|
||||
# self.log_dict(loss_dict)
|
||||
# self.log(loss)
|
||||
|
||||
# log everything
|
||||
self.log_dict(loss_dict)
|
||||
self.log("train/loss", loss)
|
||||
|
||||
return {"loss": loss, "log": loss_dict}
|
||||
|
||||
# def validation_step(self, batch, batch_idx):
|
||||
|
|
Loading…
Reference in a new issue