feat: basic maskrcnn loss logging
Former-commit-id: 2180d68e979deffb743fecb6b34f5d1a20c7d729 [formerly bf1933b5b68e107d94c3f0834fb7321407c4bf1e] Former-commit-id: 64219fc194f5f2ea41a4922afbc1e7d6dd9ee06a
This commit is contained in:
parent
4696885a30
commit
4ce22005cf
|
@ -3,17 +3,15 @@
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision.models.detection._utils import Matcher
|
|
||||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||||
from torchvision.ops.boxes import box_iou
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
def get_model_instance_segmentation(num_classes):
|
def get_model_instance_segmentation(num_classes):
|
||||||
# load an instance segmentation model pre-trained on COCO
|
# load an instance segmentation model pre-trained on COCO
|
||||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) # TODO: tester v2
|
||||||
|
|
||||||
# get number of input features for the classifier
|
# get number of input features for the classifier
|
||||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||||
|
@ -59,8 +57,11 @@ class MRCNNModule(pl.LightningModule):
|
||||||
# fasterrcnn takes both images and targets for training
|
# fasterrcnn takes both images and targets for training
|
||||||
loss_dict = self.model(images, targets)
|
loss_dict = self.model(images, targets)
|
||||||
loss = sum(loss_dict.values())
|
loss = sum(loss_dict.values())
|
||||||
# self.log_dict(loss_dict)
|
|
||||||
# self.log(loss)
|
# log everything
|
||||||
|
self.log_dict(loss_dict)
|
||||||
|
self.log("train/loss", loss)
|
||||||
|
|
||||||
return {"loss": loss, "log": loss_dict}
|
return {"loss": loss, "log": loss_dict}
|
||||||
|
|
||||||
# def validation_step(self, batch, batch_idx):
|
# def validation_step(self, batch, batch_idx):
|
||||||
|
|
Loading…
Reference in a new issue