diff --git a/src/notebooks/module.py b/src/notebooks/module.py new file mode 100644 index 0000000..71c2e5f --- /dev/null +++ b/src/notebooks/module.py @@ -0,0 +1,152 @@ +"""Pytorch lightning wrapper for model.""" + +import pytorch_lightning as pl +import torch +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor + +import wandb + + +def get_model_instance_segmentation(num_classes): + # load an instance segmentation model pre-trained on COCO + model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) + + # get number of input features for the classifier + in_features = model.roi_heads.box_predictor.cls_score.in_features + # replace the pre-trained head with a new one + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) + + # now get the number of input features for the mask classifier + in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels + hidden_layer = 256 + # and replace the mask predictor with a new one + model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes) + + return model + + +class MRCNNModule(pl.LightningModule): + def __init__(self, hidden_layer_size, n_classes): + super().__init__() + + # Hyperparameters + self.hidden_layers_size = hidden_layer_size + self.n_classes = n_classes + + # log hyperparameters + self.save_hyperparameters() + + # Network + self.model = get_model_instance_segmentation(n_classes) + + # onnx + self.example_input_array = torch.randn(1, 3, 512, 512, requires_grad=True) + + def forward(self, imgs): + self.model.eval() + return self.model(imgs) + + def training_step(self, batch, batch_idx): + # unpack batch + images, targets = batch + + # enable train mode + # self.model.train() + + # fasterrcnn takes both images and targets for training + loss_dict = self.model(images, targets) + loss_dict = {f"train/{key}": val for key, val in loss_dict.items()} + loss = sum(loss_dict.values()) + + # log everything + self.log_dict(loss_dict) + self.log("train/loss", loss) + + return {"loss": loss, "log": loss_dict} + + # def validation_step(self, batch, batch_idx): + # # unpack batch + # images, targets = batch + + # # enable eval mode + # # self.detector.eval() + + # # make a prediction + # preds = self.model(images) + + # # compute validation loss + # self.val_loss = torch.mean( + # torch.stack( + # [ + # self.accuracy( + # target, + # pred["boxes"], + # iou_threshold=0.5, + # ) + # for target, pred in zip(targets, preds) + # ], + # ) + # ) + + # return self.val_loss + + # def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0): + # """ + # The accuracy method is not the one used in the evaluator but very similar + # """ + # total_gt = len(src_boxes) + # total_pred = len(pred_boxes) + # if total_gt > 0 and total_pred > 0: + + # # Define the matcher and distance matrix based on iou + # matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False) + # match_quality_matrix = box_iou(src_boxes, pred_boxes) + + # results = matcher(match_quality_matrix) + + # true_positive = torch.count_nonzero(results.unique() != -1) + # matched_elements = results[results > -1] + + # # in Matcher, a pred element can be matched only twice + # false_positive = torch.count_nonzero(results == -1) + ( + # len(matched_elements) - len(matched_elements.unique()) + # ) + # false_negative = total_gt - true_positive + + # return true_positive / (true_positive + false_positive + false_negative) + + # elif total_gt == 0: + # if total_pred > 0: + # return torch.tensor(0.0).cuda() + # else: + # return torch.tensor(1.0).cuda() + # elif total_gt > 0 and total_pred == 0: + # return torch.tensor(0.0).cuda() + + def configure_optimizers(self): + optimizer = torch.optim.SGD( + self.parameters(), + lr=wandb.config.LEARNING_RATE, + momentum=wandb.config.MOMENTUM, + weight_decay=wandb.config.WEIGHT_DECAY, + ) + + # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + # optimizer, + # T_0=3, + # T_mult=1, + # lr=wandb.config.LEARNING_RATE_MIN, + # verbose=True, + # ) + + # return { + # "optimizer": optimizer, + # "lr_scheduler": { + # "scheduler": scheduler, + # "monitor": "val_accuracy", + # }, + # } + + return optimizer diff --git a/src/notebooks/predict.ipynb.REMOVED.git-id b/src/notebooks/predict.ipynb.REMOVED.git-id index 44b9145..9012cc5 100644 --- a/src/notebooks/predict.ipynb.REMOVED.git-id +++ b/src/notebooks/predict.ipynb.REMOVED.git-id @@ -1 +1 @@ -351003dea551ce412daaac074d767fabd060cd72 \ No newline at end of file +3abcdffb1cbc384de89a39b26f5f57b7472df478 \ No newline at end of file