feat: docstrings + typing

Former-commit-id: 38dfddce97808be4077aa7d943f34096429bc49c [formerly e5d46bbdd39dbd70b65f8deb5fe9ad66d0ecd9b0]
Former-commit-id: 35f5b2cd55cc38756e03c013ee705b137b425239
This commit is contained in:
Laurent Fainsin 2022-09-12 10:59:37 +02:00
parent ea9430a1ff
commit 185060d469
3 changed files with 85 additions and 20 deletions

1
src/modules/__init__.py Normal file
View file

@ -0,0 +1 @@
from .mrcnn import MRCNNModule

View file

@ -1,19 +1,31 @@
"""Pytorch lightning wrapper for model.""" """Mask R-CNN Pytorch Lightning Module for Object Detection and Segmentation."""
from typing import Any, Dict, List
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchvision import torchvision
import wandb import wandb
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import ( from torchvision.models.detection.mask_rcnn import (
MaskRCNN,
MaskRCNN_ResNet50_FPN_Weights, MaskRCNN_ResNet50_FPN_Weights,
MaskRCNNPredictor, MaskRCNNPredictor,
) )
Prediction = List[Dict[str, torch.Tensor]]
def get_model_instance_segmentation(num_classes):
def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
"""Returns a Torchvision MaskRCNN model for finetunning.
Args:
n_classes (int): number of classes the model should predict, background included
Returns:
MaskRCNN: the model ready to be used
"""
# load an instance segmentation model pre-trained on COCO # load an instance segmentation model pre-trained on COCO
model = torchvision.models.detection.maskrcnn_resnet50_fpn( model = torchvision.models.detection.maskrcnn_resnet50_fpn(
weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
@ -23,19 +35,26 @@ def get_model_instance_segmentation(num_classes):
# get number of input features for the classifier # get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one # replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes)
# now get the number of input features for the mask classifier # now get the number of input features for the mask classifier
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256 hidden_layer = 256
# and replace the mask predictor with a new one # and replace the mask predictor with a new one
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes) model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, n_classes)
return model return model
class MRCNNModule(pl.LightningModule): class MRCNNModule(pl.LightningModule):
def __init__(self, n_classes): """Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions."""
def __init__(self, n_classes: int) -> None:
"""Constructor, build model, save hyperparameters.
Args:
n_classes (int): number of classes the model should predict, background included
"""
super().__init__() super().__init__()
# Hyperparameters # Hyperparameters
@ -50,16 +69,38 @@ class MRCNNModule(pl.LightningModule):
# onnx export # onnx export
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half() self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
def forward(self, imgs): # torchmetrics
self.model.eval() self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
return self.model(imgs) self.metric_segm = MeanAveragePrecision(iou_type="segm")
def training_step(self, batch, batch_idx): def forward(self, imgs: torch.Tensor) -> Prediction: # type: ignore
"""Make a forward pass (prediction), usefull for onnx export.
Args:
imgs (torch.Tensor): the images whose prediction we wish to make
Returns:
torch.Tensor: the predictions
"""
self.model.eval()
pred: Prediction = self.model(imgs)
return pred
def training_step(self, batch: torch.Tensor, batch_idx: int) -> float: # type: ignore
"""PyTorch training step.
Args:
batch (torch.Tensor): the batch to train the model on
batch_idx (int): the batch index number
Returns:
float: the training loss of this step
"""
# unpack batch # unpack batch
images, targets = batch images, targets = batch
# compute loss # compute loss
loss_dict = self.model(images, targets) loss_dict: dict[str, float] = self.model(images, targets)
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()} loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
loss = sum(loss_dict.values()) loss = sum(loss_dict.values())
loss_dict["train/loss"] = loss loss_dict["train/loss"] = loss
@ -69,15 +110,28 @@ class MRCNNModule(pl.LightningModule):
return loss return loss
def on_validation_epoch_start(self): def on_validation_epoch_start(self) -> None:
self.metric_bbox = MeanAveragePrecision(iou_type="bbox") """Reset TorchMetrics."""
self.metric_segm = MeanAveragePrecision(iou_type="segm") self.metric_bbox.reset()
self.metric_segm.reset()
def validation_step(self, batch, batch_idx): def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Prediction: # type: ignore
"""PyTorch validation step.
Args:
batch (torch.Tensor): the batch to evaluate the model on
batch_idx (int): the batch index number
Returns:
torch.Tensor: the predictions
"""
# unpack batch # unpack batch
images, targets = batch images, targets = batch
preds = self.model(images) # make prediction
preds: Prediction = self.model(images)
# update TorchMetrics from predictions
for pred, target in zip(preds, targets): for pred, target in zip(preds, targets):
pred["masks"] = pred["masks"].squeeze(1).int().bool() pred["masks"] = pred["masks"].squeeze(1).int().bool()
target["masks"] = target["masks"].squeeze(1).int().bool() target["masks"] = target["masks"].squeeze(1).int().bool()
@ -86,17 +140,28 @@ class MRCNNModule(pl.LightningModule):
return preds return preds
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs: List[Prediction]) -> None: # type: ignore
# log metrics """Compute TorchMetrics.
Args:
outputs (List[Prediction]): list of predictions from validation steps
"""
# compute and log bounding boxes metrics
metric_dict = self.metric_bbox.compute() metric_dict = self.metric_bbox.compute()
metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()} metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
self.log_dict(metric_dict) self.log_dict(metric_dict)
# compute and log semgentation metrics
metric_dict = self.metric_segm.compute() metric_dict = self.metric_segm.compute()
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()} metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
self.log_dict(metric_dict) self.log_dict(metric_dict)
def configure_optimizers(self): def configure_optimizers(self) -> Dict[str, Any]:
"""PyTorch optimizers and Schedulers.
Returns:
Dict[str, Any]: dictionnary for PyTorch Lightning optimizer/scheduler configuration
"""
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
self.parameters(), self.parameters(),
lr=wandb.config.LEARNING_RATE, lr=wandb.config.LEARNING_RATE,

View file

@ -1 +0,0 @@
from .module import MRCNNModule