feat: docstrings + typing
Former-commit-id: 38dfddce97808be4077aa7d943f34096429bc49c [formerly e5d46bbdd39dbd70b65f8deb5fe9ad66d0ecd9b0] Former-commit-id: 35f5b2cd55cc38756e03c013ee705b137b425239
This commit is contained in:
parent
ea9430a1ff
commit
185060d469
1
src/modules/__init__.py
Normal file
1
src/modules/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .mrcnn import MRCNNModule
|
|
@ -1,19 +1,31 @@
|
|||
"""Pytorch lightning wrapper for model."""
|
||||
"""Mask R-CNN Pytorch Lightning Module for Object Detection and Segmentation."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
import wandb
|
||||
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
|
||||
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
from torchvision.models.detection.mask_rcnn import (
|
||||
MaskRCNN,
|
||||
MaskRCNN_ResNet50_FPN_Weights,
|
||||
MaskRCNNPredictor,
|
||||
)
|
||||
|
||||
Prediction = List[Dict[str, torch.Tensor]]
|
||||
|
||||
def get_model_instance_segmentation(num_classes):
|
||||
|
||||
def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
|
||||
"""Returns a Torchvision MaskRCNN model for finetunning.
|
||||
|
||||
Args:
|
||||
n_classes (int): number of classes the model should predict, background included
|
||||
|
||||
Returns:
|
||||
MaskRCNN: the model ready to be used
|
||||
"""
|
||||
# load an instance segmentation model pre-trained on COCO
|
||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
|
||||
weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
|
||||
|
@ -23,19 +35,26 @@ def get_model_instance_segmentation(num_classes):
|
|||
# get number of input features for the classifier
|
||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||
# replace the pre-trained head with a new one
|
||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes)
|
||||
|
||||
# now get the number of input features for the mask classifier
|
||||
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
||||
hidden_layer = 256
|
||||
# and replace the mask predictor with a new one
|
||||
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
|
||||
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, n_classes)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class MRCNNModule(pl.LightningModule):
|
||||
def __init__(self, n_classes):
|
||||
"""Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions."""
|
||||
|
||||
def __init__(self, n_classes: int) -> None:
|
||||
"""Constructor, build model, save hyperparameters.
|
||||
|
||||
Args:
|
||||
n_classes (int): number of classes the model should predict, background included
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Hyperparameters
|
||||
|
@ -50,16 +69,38 @@ class MRCNNModule(pl.LightningModule):
|
|||
# onnx export
|
||||
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
|
||||
|
||||
def forward(self, imgs):
|
||||
self.model.eval()
|
||||
return self.model(imgs)
|
||||
# torchmetrics
|
||||
self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
|
||||
self.metric_segm = MeanAveragePrecision(iou_type="segm")
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
def forward(self, imgs: torch.Tensor) -> Prediction: # type: ignore
|
||||
"""Make a forward pass (prediction), usefull for onnx export.
|
||||
|
||||
Args:
|
||||
imgs (torch.Tensor): the images whose prediction we wish to make
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the predictions
|
||||
"""
|
||||
self.model.eval()
|
||||
pred: Prediction = self.model(imgs)
|
||||
return pred
|
||||
|
||||
def training_step(self, batch: torch.Tensor, batch_idx: int) -> float: # type: ignore
|
||||
"""PyTorch training step.
|
||||
|
||||
Args:
|
||||
batch (torch.Tensor): the batch to train the model on
|
||||
batch_idx (int): the batch index number
|
||||
|
||||
Returns:
|
||||
float: the training loss of this step
|
||||
"""
|
||||
# unpack batch
|
||||
images, targets = batch
|
||||
|
||||
# compute loss
|
||||
loss_dict = self.model(images, targets)
|
||||
loss_dict: dict[str, float] = self.model(images, targets)
|
||||
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
||||
loss = sum(loss_dict.values())
|
||||
loss_dict["train/loss"] = loss
|
||||
|
@ -69,15 +110,28 @@ class MRCNNModule(pl.LightningModule):
|
|||
|
||||
return loss
|
||||
|
||||
def on_validation_epoch_start(self):
|
||||
self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
|
||||
self.metric_segm = MeanAveragePrecision(iou_type="segm")
|
||||
def on_validation_epoch_start(self) -> None:
|
||||
"""Reset TorchMetrics."""
|
||||
self.metric_bbox.reset()
|
||||
self.metric_segm.reset()
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Prediction: # type: ignore
|
||||
"""PyTorch validation step.
|
||||
|
||||
Args:
|
||||
batch (torch.Tensor): the batch to evaluate the model on
|
||||
batch_idx (int): the batch index number
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the predictions
|
||||
"""
|
||||
# unpack batch
|
||||
images, targets = batch
|
||||
|
||||
preds = self.model(images)
|
||||
# make prediction
|
||||
preds: Prediction = self.model(images)
|
||||
|
||||
# update TorchMetrics from predictions
|
||||
for pred, target in zip(preds, targets):
|
||||
pred["masks"] = pred["masks"].squeeze(1).int().bool()
|
||||
target["masks"] = target["masks"].squeeze(1).int().bool()
|
||||
|
@ -86,17 +140,28 @@ class MRCNNModule(pl.LightningModule):
|
|||
|
||||
return preds
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
# log metrics
|
||||
def validation_epoch_end(self, outputs: List[Prediction]) -> None: # type: ignore
|
||||
"""Compute TorchMetrics.
|
||||
|
||||
Args:
|
||||
outputs (List[Prediction]): list of predictions from validation steps
|
||||
"""
|
||||
# compute and log bounding boxes metrics
|
||||
metric_dict = self.metric_bbox.compute()
|
||||
metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
|
||||
self.log_dict(metric_dict)
|
||||
|
||||
# compute and log semgentation metrics
|
||||
metric_dict = self.metric_segm.compute()
|
||||
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
|
||||
self.log_dict(metric_dict)
|
||||
|
||||
def configure_optimizers(self):
|
||||
def configure_optimizers(self) -> Dict[str, Any]:
|
||||
"""PyTorch optimizers and Schedulers.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: dictionnary for PyTorch Lightning optimizer/scheduler configuration
|
||||
"""
|
||||
optimizer = torch.optim.Adam(
|
||||
self.parameters(),
|
||||
lr=wandb.config.LEARNING_RATE,
|
|
@ -1 +0,0 @@
|
|||
from .module import MRCNNModule
|
Loading…
Reference in a new issue