mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
chore: bettered predict.ipynb
Former-commit-id: 375b2fa79b8b94f815f261598592d11d2ac8a92d [formerly 547c360afa04fc6986831422ce861a8470563c00] Former-commit-id: 2915fc814fcf760ee4af9fb2447fe4f1cf163d2e
This commit is contained in:
parent
c471342681
commit
04ddf75dd8
152
src/notebooks/module.py
Normal file
152
src/notebooks/module.py
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
"""Pytorch lightning wrapper for model."""
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
|
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_instance_segmentation(num_classes):
|
||||||
|
# load an instance segmentation model pre-trained on COCO
|
||||||
|
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
||||||
|
|
||||||
|
# get number of input features for the classifier
|
||||||
|
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||||
|
# replace the pre-trained head with a new one
|
||||||
|
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
||||||
|
|
||||||
|
# now get the number of input features for the mask classifier
|
||||||
|
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
||||||
|
hidden_layer = 256
|
||||||
|
# and replace the mask predictor with a new one
|
||||||
|
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class MRCNNModule(pl.LightningModule):
|
||||||
|
def __init__(self, hidden_layer_size, n_classes):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
self.hidden_layers_size = hidden_layer_size
|
||||||
|
self.n_classes = n_classes
|
||||||
|
|
||||||
|
# log hyperparameters
|
||||||
|
self.save_hyperparameters()
|
||||||
|
|
||||||
|
# Network
|
||||||
|
self.model = get_model_instance_segmentation(n_classes)
|
||||||
|
|
||||||
|
# onnx
|
||||||
|
self.example_input_array = torch.randn(1, 3, 512, 512, requires_grad=True)
|
||||||
|
|
||||||
|
def forward(self, imgs):
|
||||||
|
self.model.eval()
|
||||||
|
return self.model(imgs)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
# unpack batch
|
||||||
|
images, targets = batch
|
||||||
|
|
||||||
|
# enable train mode
|
||||||
|
# self.model.train()
|
||||||
|
|
||||||
|
# fasterrcnn takes both images and targets for training
|
||||||
|
loss_dict = self.model(images, targets)
|
||||||
|
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
||||||
|
loss = sum(loss_dict.values())
|
||||||
|
|
||||||
|
# log everything
|
||||||
|
self.log_dict(loss_dict)
|
||||||
|
self.log("train/loss", loss)
|
||||||
|
|
||||||
|
return {"loss": loss, "log": loss_dict}
|
||||||
|
|
||||||
|
# def validation_step(self, batch, batch_idx):
|
||||||
|
# # unpack batch
|
||||||
|
# images, targets = batch
|
||||||
|
|
||||||
|
# # enable eval mode
|
||||||
|
# # self.detector.eval()
|
||||||
|
|
||||||
|
# # make a prediction
|
||||||
|
# preds = self.model(images)
|
||||||
|
|
||||||
|
# # compute validation loss
|
||||||
|
# self.val_loss = torch.mean(
|
||||||
|
# torch.stack(
|
||||||
|
# [
|
||||||
|
# self.accuracy(
|
||||||
|
# target,
|
||||||
|
# pred["boxes"],
|
||||||
|
# iou_threshold=0.5,
|
||||||
|
# )
|
||||||
|
# for target, pred in zip(targets, preds)
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return self.val_loss
|
||||||
|
|
||||||
|
# def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
|
||||||
|
# """
|
||||||
|
# The accuracy method is not the one used in the evaluator but very similar
|
||||||
|
# """
|
||||||
|
# total_gt = len(src_boxes)
|
||||||
|
# total_pred = len(pred_boxes)
|
||||||
|
# if total_gt > 0 and total_pred > 0:
|
||||||
|
|
||||||
|
# # Define the matcher and distance matrix based on iou
|
||||||
|
# matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
|
||||||
|
# match_quality_matrix = box_iou(src_boxes, pred_boxes)
|
||||||
|
|
||||||
|
# results = matcher(match_quality_matrix)
|
||||||
|
|
||||||
|
# true_positive = torch.count_nonzero(results.unique() != -1)
|
||||||
|
# matched_elements = results[results > -1]
|
||||||
|
|
||||||
|
# # in Matcher, a pred element can be matched only twice
|
||||||
|
# false_positive = torch.count_nonzero(results == -1) + (
|
||||||
|
# len(matched_elements) - len(matched_elements.unique())
|
||||||
|
# )
|
||||||
|
# false_negative = total_gt - true_positive
|
||||||
|
|
||||||
|
# return true_positive / (true_positive + false_positive + false_negative)
|
||||||
|
|
||||||
|
# elif total_gt == 0:
|
||||||
|
# if total_pred > 0:
|
||||||
|
# return torch.tensor(0.0).cuda()
|
||||||
|
# else:
|
||||||
|
# return torch.tensor(1.0).cuda()
|
||||||
|
# elif total_gt > 0 and total_pred == 0:
|
||||||
|
# return torch.tensor(0.0).cuda()
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
self.parameters(),
|
||||||
|
lr=wandb.config.LEARNING_RATE,
|
||||||
|
momentum=wandb.config.MOMENTUM,
|
||||||
|
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||||
|
)
|
||||||
|
|
||||||
|
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
# optimizer,
|
||||||
|
# T_0=3,
|
||||||
|
# T_mult=1,
|
||||||
|
# lr=wandb.config.LEARNING_RATE_MIN,
|
||||||
|
# verbose=True,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# "optimizer": optimizer,
|
||||||
|
# "lr_scheduler": {
|
||||||
|
# "scheduler": scheduler,
|
||||||
|
# "monitor": "val_accuracy",
|
||||||
|
# },
|
||||||
|
# }
|
||||||
|
|
||||||
|
return optimizer
|
|
@ -1 +1 @@
|
||||||
351003dea551ce412daaac074d767fabd060cd72
|
3abcdffb1cbc384de89a39b26f5f57b7472df478
|
Loading…
Reference in a new issue