diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..8076f54 --- /dev/null +++ b/src/modules/__init__.py @@ -0,0 +1 @@ +from .mrcnn import MRCNNModule diff --git a/src/mrcnn/module.py b/src/modules/mrcnn.py similarity index 54% rename from src/mrcnn/module.py rename to src/modules/mrcnn.py index fcee13b..da19a1e 100644 --- a/src/mrcnn/module.py +++ b/src/modules/mrcnn.py @@ -1,19 +1,31 @@ -"""Pytorch lightning wrapper for model.""" +"""Mask R-CNN Pytorch Lightning Module for Object Detection and Segmentation.""" + +from typing import Any, Dict, List import pytorch_lightning as pl import torch import torchvision import wandb -from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.mask_rcnn import ( + MaskRCNN, MaskRCNN_ResNet50_FPN_Weights, MaskRCNNPredictor, ) +Prediction = List[Dict[str, torch.Tensor]] -def get_model_instance_segmentation(num_classes): + +def get_model_instance_segmentation(n_classes: int) -> MaskRCNN: + """Returns a Torchvision MaskRCNN model for finetunning. + + Args: + n_classes (int): number of classes the model should predict, background included + + Returns: + MaskRCNN: the model ready to be used + """ # load an instance segmentation model pre-trained on COCO model = torchvision.models.detection.maskrcnn_resnet50_fpn( weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, @@ -23,19 +35,26 @@ def get_model_instance_segmentation(num_classes): # get number of input features for the classifier in_features = model.roi_heads.box_predictor.cls_score.in_features # replace the pre-trained head with a new one - model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes) # now get the number of input features for the mask classifier in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels hidden_layer = 256 # and replace the mask predictor with a new one - model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes) + model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, n_classes) return model class MRCNNModule(pl.LightningModule): - def __init__(self, n_classes): + """Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions.""" + + def __init__(self, n_classes: int) -> None: + """Constructor, build model, save hyperparameters. + + Args: + n_classes (int): number of classes the model should predict, background included + """ super().__init__() # Hyperparameters @@ -50,16 +69,38 @@ class MRCNNModule(pl.LightningModule): # onnx export self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half() - def forward(self, imgs): - self.model.eval() - return self.model(imgs) + # torchmetrics + self.metric_bbox = MeanAveragePrecision(iou_type="bbox") + self.metric_segm = MeanAveragePrecision(iou_type="segm") - def training_step(self, batch, batch_idx): + def forward(self, imgs: torch.Tensor) -> Prediction: # type: ignore + """Make a forward pass (prediction), usefull for onnx export. + + Args: + imgs (torch.Tensor): the images whose prediction we wish to make + + Returns: + torch.Tensor: the predictions + """ + self.model.eval() + pred: Prediction = self.model(imgs) + return pred + + def training_step(self, batch: torch.Tensor, batch_idx: int) -> float: # type: ignore + """PyTorch training step. + + Args: + batch (torch.Tensor): the batch to train the model on + batch_idx (int): the batch index number + + Returns: + float: the training loss of this step + """ # unpack batch images, targets = batch # compute loss - loss_dict = self.model(images, targets) + loss_dict: dict[str, float] = self.model(images, targets) loss_dict = {f"train/{key}": val for key, val in loss_dict.items()} loss = sum(loss_dict.values()) loss_dict["train/loss"] = loss @@ -69,15 +110,28 @@ class MRCNNModule(pl.LightningModule): return loss - def on_validation_epoch_start(self): - self.metric_bbox = MeanAveragePrecision(iou_type="bbox") - self.metric_segm = MeanAveragePrecision(iou_type="segm") + def on_validation_epoch_start(self) -> None: + """Reset TorchMetrics.""" + self.metric_bbox.reset() + self.metric_segm.reset() - def validation_step(self, batch, batch_idx): + def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Prediction: # type: ignore + """PyTorch validation step. + + Args: + batch (torch.Tensor): the batch to evaluate the model on + batch_idx (int): the batch index number + + Returns: + torch.Tensor: the predictions + """ # unpack batch images, targets = batch - preds = self.model(images) + # make prediction + preds: Prediction = self.model(images) + + # update TorchMetrics from predictions for pred, target in zip(preds, targets): pred["masks"] = pred["masks"].squeeze(1).int().bool() target["masks"] = target["masks"].squeeze(1).int().bool() @@ -86,17 +140,28 @@ class MRCNNModule(pl.LightningModule): return preds - def validation_epoch_end(self, outputs): - # log metrics + def validation_epoch_end(self, outputs: List[Prediction]) -> None: # type: ignore + """Compute TorchMetrics. + + Args: + outputs (List[Prediction]): list of predictions from validation steps + """ + # compute and log bounding boxes metrics metric_dict = self.metric_bbox.compute() metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()} self.log_dict(metric_dict) + # compute and log semgentation metrics metric_dict = self.metric_segm.compute() metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()} self.log_dict(metric_dict) - def configure_optimizers(self): + def configure_optimizers(self) -> Dict[str, Any]: + """PyTorch optimizers and Schedulers. + + Returns: + Dict[str, Any]: dictionnary for PyTorch Lightning optimizer/scheduler configuration + """ optimizer = torch.optim.Adam( self.parameters(), lr=wandb.config.LEARNING_RATE, diff --git a/src/mrcnn/__init__.py b/src/mrcnn/__init__.py deleted file mode 100644 index 8629291..0000000 --- a/src/mrcnn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .module import MRCNNModule