mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
feat: docstrings + typing
Former-commit-id: 38dfddce97808be4077aa7d943f34096429bc49c [formerly e5d46bbdd39dbd70b65f8deb5fe9ad66d0ecd9b0] Former-commit-id: 35f5b2cd55cc38756e03c013ee705b137b425239
This commit is contained in:
parent
ea9430a1ff
commit
185060d469
1
src/modules/__init__.py
Normal file
1
src/modules/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .mrcnn import MRCNNModule
|
|
@ -1,19 +1,31 @@
|
||||||
"""Pytorch lightning wrapper for model."""
|
"""Mask R-CNN Pytorch Lightning Module for Object Detection and Segmentation."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
import wandb
|
import wandb
|
||||||
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
|
|
||||||
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
from torchvision.models.detection.mask_rcnn import (
|
from torchvision.models.detection.mask_rcnn import (
|
||||||
|
MaskRCNN,
|
||||||
MaskRCNN_ResNet50_FPN_Weights,
|
MaskRCNN_ResNet50_FPN_Weights,
|
||||||
MaskRCNNPredictor,
|
MaskRCNNPredictor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Prediction = List[Dict[str, torch.Tensor]]
|
||||||
|
|
||||||
def get_model_instance_segmentation(num_classes):
|
|
||||||
|
def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
|
||||||
|
"""Returns a Torchvision MaskRCNN model for finetunning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_classes (int): number of classes the model should predict, background included
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MaskRCNN: the model ready to be used
|
||||||
|
"""
|
||||||
# load an instance segmentation model pre-trained on COCO
|
# load an instance segmentation model pre-trained on COCO
|
||||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
|
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
|
||||||
weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
|
weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
|
||||||
|
@ -23,19 +35,26 @@ def get_model_instance_segmentation(num_classes):
|
||||||
# get number of input features for the classifier
|
# get number of input features for the classifier
|
||||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||||
# replace the pre-trained head with a new one
|
# replace the pre-trained head with a new one
|
||||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes)
|
||||||
|
|
||||||
# now get the number of input features for the mask classifier
|
# now get the number of input features for the mask classifier
|
||||||
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
||||||
hidden_layer = 256
|
hidden_layer = 256
|
||||||
# and replace the mask predictor with a new one
|
# and replace the mask predictor with a new one
|
||||||
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
|
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, n_classes)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class MRCNNModule(pl.LightningModule):
|
class MRCNNModule(pl.LightningModule):
|
||||||
def __init__(self, n_classes):
|
"""Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions."""
|
||||||
|
|
||||||
|
def __init__(self, n_classes: int) -> None:
|
||||||
|
"""Constructor, build model, save hyperparameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_classes (int): number of classes the model should predict, background included
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
|
@ -50,16 +69,38 @@ class MRCNNModule(pl.LightningModule):
|
||||||
# onnx export
|
# onnx export
|
||||||
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
|
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
|
||||||
|
|
||||||
def forward(self, imgs):
|
# torchmetrics
|
||||||
self.model.eval()
|
self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
|
||||||
return self.model(imgs)
|
self.metric_segm = MeanAveragePrecision(iou_type="segm")
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def forward(self, imgs: torch.Tensor) -> Prediction: # type: ignore
|
||||||
|
"""Make a forward pass (prediction), usefull for onnx export.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imgs (torch.Tensor): the images whose prediction we wish to make
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: the predictions
|
||||||
|
"""
|
||||||
|
self.model.eval()
|
||||||
|
pred: Prediction = self.model(imgs)
|
||||||
|
return pred
|
||||||
|
|
||||||
|
def training_step(self, batch: torch.Tensor, batch_idx: int) -> float: # type: ignore
|
||||||
|
"""PyTorch training step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (torch.Tensor): the batch to train the model on
|
||||||
|
batch_idx (int): the batch index number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: the training loss of this step
|
||||||
|
"""
|
||||||
# unpack batch
|
# unpack batch
|
||||||
images, targets = batch
|
images, targets = batch
|
||||||
|
|
||||||
# compute loss
|
# compute loss
|
||||||
loss_dict = self.model(images, targets)
|
loss_dict: dict[str, float] = self.model(images, targets)
|
||||||
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
||||||
loss = sum(loss_dict.values())
|
loss = sum(loss_dict.values())
|
||||||
loss_dict["train/loss"] = loss
|
loss_dict["train/loss"] = loss
|
||||||
|
@ -69,15 +110,28 @@ class MRCNNModule(pl.LightningModule):
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def on_validation_epoch_start(self):
|
def on_validation_epoch_start(self) -> None:
|
||||||
self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
|
"""Reset TorchMetrics."""
|
||||||
self.metric_segm = MeanAveragePrecision(iou_type="segm")
|
self.metric_bbox.reset()
|
||||||
|
self.metric_segm.reset()
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Prediction: # type: ignore
|
||||||
|
"""PyTorch validation step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (torch.Tensor): the batch to evaluate the model on
|
||||||
|
batch_idx (int): the batch index number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: the predictions
|
||||||
|
"""
|
||||||
# unpack batch
|
# unpack batch
|
||||||
images, targets = batch
|
images, targets = batch
|
||||||
|
|
||||||
preds = self.model(images)
|
# make prediction
|
||||||
|
preds: Prediction = self.model(images)
|
||||||
|
|
||||||
|
# update TorchMetrics from predictions
|
||||||
for pred, target in zip(preds, targets):
|
for pred, target in zip(preds, targets):
|
||||||
pred["masks"] = pred["masks"].squeeze(1).int().bool()
|
pred["masks"] = pred["masks"].squeeze(1).int().bool()
|
||||||
target["masks"] = target["masks"].squeeze(1).int().bool()
|
target["masks"] = target["masks"].squeeze(1).int().bool()
|
||||||
|
@ -86,17 +140,28 @@ class MRCNNModule(pl.LightningModule):
|
||||||
|
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs: List[Prediction]) -> None: # type: ignore
|
||||||
# log metrics
|
"""Compute TorchMetrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (List[Prediction]): list of predictions from validation steps
|
||||||
|
"""
|
||||||
|
# compute and log bounding boxes metrics
|
||||||
metric_dict = self.metric_bbox.compute()
|
metric_dict = self.metric_bbox.compute()
|
||||||
metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
|
metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
|
||||||
self.log_dict(metric_dict)
|
self.log_dict(metric_dict)
|
||||||
|
|
||||||
|
# compute and log semgentation metrics
|
||||||
metric_dict = self.metric_segm.compute()
|
metric_dict = self.metric_segm.compute()
|
||||||
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
|
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
|
||||||
self.log_dict(metric_dict)
|
self.log_dict(metric_dict)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self) -> Dict[str, Any]:
|
||||||
|
"""PyTorch optimizers and Schedulers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: dictionnary for PyTorch Lightning optimizer/scheduler configuration
|
||||||
|
"""
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
self.parameters(),
|
self.parameters(),
|
||||||
lr=wandb.config.LEARNING_RATE,
|
lr=wandb.config.LEARNING_RATE,
|
|
@ -1 +0,0 @@
|
||||||
from .module import MRCNNModule
|
|
Loading…
Reference in a new issue