feat: WIP, replacing U-Net by Mask R-CNN

Former-commit-id: f51a572adac901ff588e3a467f39ecd26376e617 [formerly 376595d7e5f906928379e25c1246e304b96b156d]
Former-commit-id: 3f4772ba3483702be6e5f7a29f06be93eb1f3bb2
This commit is contained in:
Laurent Fainsin 2022-08-24 14:56:41 +02:00
parent c50235bb1e
commit 4dab157dda
6 changed files with 236 additions and 19 deletions

View file

@ -1,8 +1,8 @@
import albumentations as A
import pytorch_lightning as pl
import wandb
from torch.utils.data import DataLoader, Subset
import wandb
from utils import RandomPaste
from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
@ -26,7 +26,7 @@ class Spheres(pl.LightningDataModule):
# dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TRAIN_prerender/")
dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
return DataLoader(
@ -38,15 +38,15 @@ class Spheres(pl.LightningDataModule):
pin_memory=wandb.config.PIN_MEMORY,
)
def val_dataloader(self):
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
# def val_dataloader(self):
# dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
# dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
return DataLoader(
dataset,
shuffle=False,
prefetch_factor=wandb.config.PREFETCH_FACTOR,
batch_size=wandb.config.VAL_BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
# return DataLoader(
# dataset,
# shuffle=False,
# prefetch_factor=wandb.config.PREFETCH_FACTOR,
# batch_size=wandb.config.VAL_BATCH_SIZE,
# num_workers=wandb.config.WORKERS,
# pin_memory=wandb.config.PIN_MEMORY,
# )

View file

@ -1,7 +1,9 @@
import os
from pathlib import Path
import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset
@ -111,3 +113,66 @@ class LabeledDataset2(Dataset):
mask = mask.float()
return image, mask
class LabeledDataset3(object):
def __init__(self, root, transforms):
self.root = root
self.transforms = transforms
# load all image files, sorting them to ensure that they are aligned
self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
def __getitem__(self, idx):
# create paths from ids
img_path = os.path.join(self.root, "images", self.imgs[idx])
mask_path = os.path.join(self.root, "masks", self.masks[idx])
# load image and mask
img = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path)
# convert mask to numpy array to apply operations
mask = np.array(mask)
obj_ids = np.unique(mask)
obj_ids = obj_ids[1:] # first id is the background, so remove it
# split the color-encoded mask into a set of binary masks
masks = mask == obj_ids[:, None, None]
# get bounding box coordinates for each mask
num_objs = len(obj_ids)
bboxes = []
for i in range(num_objs):
pos = np.where(masks[i])
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
ymax = np.max(pos[0])
bboxes.append([xmin, ymin, xmax, ymax])
# convert arrays to tensors
bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
labels = torch.ones((num_objs,), dtype=torch.int64) # there is only one class
masks = torch.as_tensor(masks, dtype=torch.uint8)
# image_id = torch.tensor([idx])
# area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
# iscrowd = torch.zeros((num_objs,), dtype=torch.int64) # suppose all instances are not crowd
target = {}
target["boxes"] = bboxes
target["labels"] = labels
target["masks"] = masks
# target["image_id"] = image_id
# target["area"] = area
# target["iscrowd"] = iscrowd
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.imgs)

1
src/mrcnn/__init__.py Normal file
View file

@ -0,0 +1 @@
from .module import MRCNNModule

144
src/mrcnn/module.py Normal file
View file

@ -0,0 +1,144 @@
"""Pytorch lightning wrapper for model."""
import pytorch_lightning as pl
import torch
import torchvision
import wandb
from torchvision.models.detection._utils import Matcher
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.ops.boxes import box_iou
def get_model_instance_segmentation(num_classes):
# load an instance segmentation model pre-trained on COCO
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# now get the number of input features for the mask classifier
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
# and replace the mask predictor with a new one
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
return model
class MRCNNModule(pl.LightningModule):
def __init__(self, hidden_layer_size, n_classes):
super().__init__()
# Hyperparameters
self.hidden_layers_size = hidden_layer_size
self.n_classes = n_classes
# log hyperparameters
self.save_hyperparameters()
# Network
self.model = get_model_instance_segmentation(n_classes)
def forward(self, imgs):
# Torchvision FasterRCNN returns the loss during training
# and the boxes during eval
self.model.eval()
return self.model(imgs)
def training_step(self, batch, batch_idx):
# unpack batch
images, targets = batch
# enable train mode
self.model.train()
# fasterrcnn takes both images and targets for training
loss_dict = self.model(images, targets)
loss = sum(loss_dict.values())
return {"loss": loss, "log": loss_dict}
def validation_step(self, batch, batch_idx):
# unpack batch
images, targets = batch
# enable eval mode
self.detector.eval()
# make a prediction
preds = self.detector(images)
# compute validation loss
self.val_loss = torch.mean(
torch.stack(
[
self.accuracy(
target,
pred["boxes"],
iou_threshold=0.5,
)
for target, pred in zip(targets, preds)
],
)
)
return self.val_loss
def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
"""
The accuracy method is not the one used in the evaluator but very similar
"""
total_gt = len(src_boxes)
total_pred = len(pred_boxes)
if total_gt > 0 and total_pred > 0:
# Define the matcher and distance matrix based on iou
matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
match_quality_matrix = box_iou(src_boxes, pred_boxes)
results = matcher(match_quality_matrix)
true_positive = torch.count_nonzero(results.unique() != -1)
matched_elements = results[results > -1]
# in Matcher, a pred element can be matched only twice
false_positive = torch.count_nonzero(results == -1) + (
len(matched_elements) - len(matched_elements.unique())
)
false_negative = total_gt - true_positive
return true_positive / (true_positive + false_positive + false_negative)
elif total_gt == 0:
if total_pred > 0:
return torch.tensor(0.0).cuda()
else:
return torch.tensor(1.0).cuda()
elif total_gt > 0 and total_pred == 0:
return torch.tensor(0.0).cuda()
def configure_optimizers(self):
optimizer = torch.optim.SGD(
self.parameters(),
lr=wandb.config.LEARNING_RATE,
momentum=wandb.config.MOMENTUM,
weight_decay=wandb.config.WEIGHT_DECAY,
nesterov=wandb.config.NESTEROV,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=3,
T_mult=1,
lr=wandb.config.LEARNING_RATE_MIN,
verbose=True,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_accuracy",
},
}

View file

@ -1,11 +1,12 @@
import logging
import pytorch_lightning as pl
import wandb
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
import wandb
from data import Spheres
from mrcnn import MRCNNModule
from unet import UNetModule
from utils import ArtifactLog, TableLog
@ -26,10 +27,15 @@ if __name__ == "__main__":
pl.seed_everything(69420, workers=True)
# Create network
model = UNetModule(
n_channels=wandb.config.N_CHANNELS,
n_classes=wandb.config.N_CLASSES,
features=wandb.config.FEATURES,
# model = UNetModule(
# n_channels=wandb.config.N_CHANNELS,
# n_classes=wandb.config.N_CLASSES,
# features=wandb.config.FEATURES,
# )
model = MRCNNModule(
hidden_layer_size=-1,
n_classes=2,
)
# load checkpoint
@ -48,6 +54,7 @@ if __name__ == "__main__":
max_epochs=wandb.config.EPOCHS,
accelerator=wandb.config.DEVICE,
benchmark=wandb.config.BENCHMARK,
deterministic=True,
precision=16,
logger=logger,
log_every_n_steps=1,

View file

@ -29,7 +29,7 @@ SPHERES:
value: 3
EPOCHS:
value: 1
value: 3
TRAIN_BATCH_SIZE:
value: 128 # 100
VAL_BATCH_SIZE: