feat: WIP, replacing U-Net with Mask R-CNN

Former-commit-id: f51a572adac901ff588e3a467f39ecd26376e617 [formerly 376595d7e5f906928379e25c1246e304b96b156d]
Former-commit-id: 3f4772ba3483702be6e5f7a29f06be93eb1f3bb2
Author: Laurent Fainsin
Date: 2022-08-24 14:56:41 +02:00
Parent: c50235bb1e
Commit: 4dab157dda
6 changed files with 236 additions and 19 deletions

@@ -1,8 +1,8 @@
 import albumentations as A
 import pytorch_lightning as pl
-import wandb
 from torch.utils.data import DataLoader, Subset
+import wandb

 from utils import RandomPaste

 from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
@@ -26,7 +26,7 @@ class Spheres(pl.LightningDataModule):
         # dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
-        dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TRAIN_prerender/")
+        dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
         dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows to better utilize the gpu

         return DataLoader(
@@ -38,15 +38,15 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )

-    def val_dataloader(self):
-        dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows to better utilize the gpu
-        return DataLoader(
-            dataset,
-            shuffle=False,
-            prefetch_factor=wandb.config.PREFETCH_FACTOR,
-            batch_size=wandb.config.VAL_BATCH_SIZE,
-            num_workers=wandb.config.WORKERS,
-            pin_memory=wandb.config.PIN_MEMORY,
-        )
+    # def val_dataloader(self):
+    #     dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
+    #     dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows to better utilize the gpu
+    #     return DataLoader(
+    #         dataset,
+    #         shuffle=False,
+    #         prefetch_factor=wandb.config.PREFETCH_FACTOR,
+    #         batch_size=wandb.config.VAL_BATCH_SIZE,
+    #         num_workers=wandb.config.WORKERS,
+    #         pin_memory=wandb.config.PIN_MEMORY,
+    #     )
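
Note on this datamodule (not part of the diff itself): torchvision's Mask R-CNN expects each batch as a list of image tensors plus a list of per-image target dicts, so the default stacking collation breaks once targets contain a variable number of boxes and masks. A minimal sketch of the collate_fn the DataLoader above would likely need; the function name is an assumption, not something present in this commit:

# hypothetical helper, not present in this commit
def collate_fn(batch):
    # turn [(img1, tgt1), (img2, tgt2), ...] into ([img1, img2, ...], [tgt1, tgt2, ...])
    return tuple(zip(*batch))

# e.g. DataLoader(dataset, batch_size=wandb.config.TRAIN_BATCH_SIZE, collate_fn=collate_fn, ...)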

@@ -1,7 +1,9 @@
+import os
 from pathlib import Path

 import albumentations as A
 import numpy as np
+import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image
 from torch.utils.data import Dataset
@@ -111,3 +113,66 @@ class LabeledDataset2(Dataset):
         mask = mask.float()

         return image, mask
+
+
+class LabeledDataset3(object):
+    def __init__(self, root, transforms):
+        self.root = root
+        self.transforms = transforms
+        # load all image files, sorting them to ensure that they are aligned
+        self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
+        self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
+
+    def __getitem__(self, idx):
+        # create paths from ids
+        img_path = os.path.join(self.root, "images", self.imgs[idx])
+        mask_path = os.path.join(self.root, "masks", self.masks[idx])
+
+        # load image and mask
+        img = Image.open(img_path).convert("RGB")
+        mask = Image.open(mask_path)
+
+        # convert mask to numpy array to apply operations
+        mask = np.array(mask)
+        obj_ids = np.unique(mask)
+        obj_ids = obj_ids[1:]  # first id is the background, so remove it
+
+        # split the color-encoded mask into a set of binary masks
+        masks = mask == obj_ids[:, None, None]
+
+        # get bounding box coordinates for each mask
+        num_objs = len(obj_ids)
+        bboxes = []
+        for i in range(num_objs):
+            pos = np.where(masks[i])
+            xmin = np.min(pos[1])
+            xmax = np.max(pos[1])
+            ymin = np.min(pos[0])
+            ymax = np.max(pos[0])
+            bboxes.append([xmin, ymin, xmax, ymax])
+
+        # convert arrays to tensors
+        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
+        labels = torch.ones((num_objs,), dtype=torch.int64)  # there is only one class
+        masks = torch.as_tensor(masks, dtype=torch.uint8)
+
+        # image_id = torch.tensor([idx])
+        # area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        # iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # suppose all instances are not crowd
+
+        target = {}
+        target["boxes"] = bboxes
+        target["labels"] = labels
+        target["masks"] = masks
+        # target["image_id"] = image_id
+        # target["area"] = area
+        # target["iscrowd"] = iscrowd
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
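
Since LabeledDataset3.__getitem__ calls self.transforms(img, target), any transform passed in has to accept and return the (image, target) pair rather than the image alone. A minimal sketch of how the class could be exercised; the ToTensorPair helper and the root path are assumptions for illustration, not part of this commit:

import torchvision.transforms.functional as TF

class ToTensorPair:
    # hypothetical pair-aware transform: converts the PIL image, leaves the target untouched
    def __call__(self, img, target):
        return TF.to_tensor(img), target

dataset = LabeledDataset3(root="/path/to/dataset", transforms=ToTensorPair())
img, target = dataset[0]
print(img.shape, target["boxes"].shape, target["masks"].shape, target["labels"])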

src/mrcnn/__init__.py (new file)

@@ -0,0 +1 @@
+from .module import MRCNNModule

src/mrcnn/module.py (new file)

@@ -0,0 +1,144 @@
+"""PyTorch Lightning wrapper for the Mask R-CNN model."""
+
+import pytorch_lightning as pl
+import torch
+import torchvision
+import wandb
+from torchvision.models.detection._utils import Matcher
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.ops.boxes import box_iou
+
+
+def get_model_instance_segmentation(num_classes):
+    # load an instance segmentation model pre-trained on COCO
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
+
+    # get number of input features for the classifier
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+    # replace the pre-trained head with a new one
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+
+    # now get the number of input features for the mask classifier
+    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
+    hidden_layer = 256
+    # and replace the mask predictor with a new one
+    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
+
+    return model
+
+
+class MRCNNModule(pl.LightningModule):
+    def __init__(self, hidden_layer_size, n_classes):
+        super().__init__()
+
+        # Hyperparameters
+        self.hidden_layers_size = hidden_layer_size
+        self.n_classes = n_classes
+
+        # log hyperparameters
+        self.save_hyperparameters()
+
+        # Network
+        self.model = get_model_instance_segmentation(n_classes)
+
+    def forward(self, imgs):
+        # Torchvision FasterRCNN returns the loss during training
+        # and the boxes during eval
+        self.model.eval()
+        return self.model(imgs)
+
+    def training_step(self, batch, batch_idx):
+        # unpack batch
+        images, targets = batch
+
+        # enable train mode
+        self.model.train()
+
+        # fasterrcnn takes both images and targets for training
+        loss_dict = self.model(images, targets)
+        loss = sum(loss_dict.values())
+
+        return {"loss": loss, "log": loss_dict}
+
+    def validation_step(self, batch, batch_idx):
+        # unpack batch
+        images, targets = batch
+
+        # enable eval mode
+        self.model.eval()
+
+        # make a prediction
+        preds = self.model(images)
+
+        # compute the validation metric (mean box-matching accuracy over the batch)
+        self.val_loss = torch.mean(
+            torch.stack(
+                [
+                    self.accuracy(
+                        target["boxes"],
+                        pred["boxes"],
+                        iou_threshold=0.5,
+                    )
+                    for target, pred in zip(targets, preds)
+                ],
+            )
+        )
+
+        return self.val_loss
+
+    def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
+        """This accuracy is not the one used in the evaluator, but it is very similar."""
+        total_gt = len(src_boxes)
+        total_pred = len(pred_boxes)
+
+        if total_gt > 0 and total_pred > 0:
+            # Define the matcher and distance matrix based on iou
+            matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
+            match_quality_matrix = box_iou(src_boxes, pred_boxes)
+
+            results = matcher(match_quality_matrix)
+
+            true_positive = torch.count_nonzero(results.unique() != -1)
+            matched_elements = results[results > -1]
+
+            # in Matcher, a pred element can be matched only twice
+            false_positive = torch.count_nonzero(results == -1) + (
+                len(matched_elements) - len(matched_elements.unique())
+            )
+            false_negative = total_gt - true_positive
+
+            return true_positive / (true_positive + false_positive + false_negative)
+
+        elif total_gt == 0:
+            if total_pred > 0:
+                return torch.tensor(0.0).cuda()
+            else:
+                return torch.tensor(1.0).cuda()
+
+        elif total_gt > 0 and total_pred == 0:
+            return torch.tensor(0.0).cuda()
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(
+            self.parameters(),
+            lr=wandb.config.LEARNING_RATE,
+            momentum=wandb.config.MOMENTUM,
+            weight_decay=wandb.config.WEIGHT_DECAY,
+            nesterov=wandb.config.NESTEROV,
+        )
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer,
+            T_0=3,
+            T_mult=1,
+            eta_min=wandb.config.LEARNING_RATE_MIN,
+            verbose=True,
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "monitor": "val_accuracy",
+            },
+        }
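
As a quick sanity check of the new module, the sketch below instantiates it and runs a forward pass on a random image; in eval mode torchvision's Mask R-CNN returns one dict per image containing boxes, labels, scores and masks. The image size and the whole snippet are illustrative only, not part of this commit:

import torch
from mrcnn import MRCNNModule

module = MRCNNModule(hidden_layer_size=-1, n_classes=2)  # downloads COCO weights on first use
imgs = [torch.rand(3, 512, 512)]  # detection models take a list of CHW tensors in [0, 1]
preds = module(imgs)              # forward() puts the model in eval mode
print(preds[0].keys())            # expected keys: boxes, labels, scores, masks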

@@ -1,11 +1,12 @@
 import logging

 import pytorch_lightning as pl
-import wandb
 from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
+import wandb

 from data import Spheres
+from mrcnn import MRCNNModule
 from unet import UNetModule
 from utils import ArtifactLog, TableLog
@@ -26,10 +27,15 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)

     # Create network
-    model = UNetModule(
-        n_channels=wandb.config.N_CHANNELS,
-        n_classes=wandb.config.N_CLASSES,
-        features=wandb.config.FEATURES,
+    # model = UNetModule(
+    #     n_channels=wandb.config.N_CHANNELS,
+    #     n_classes=wandb.config.N_CLASSES,
+    #     features=wandb.config.FEATURES,
+    # )
+    model = MRCNNModule(
+        hidden_layer_size=-1,
+        n_classes=2,
     )

     # load checkpoint
@@ -48,6 +54,7 @@ if __name__ == "__main__":
         max_epochs=wandb.config.EPOCHS,
         accelerator=wandb.config.DEVICE,
         benchmark=wandb.config.BENCHMARK,
+        deterministic=True,
         precision=16,
         logger=logger,
         log_every_n_steps=1,
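
For completeness, a rough sketch of how the pieces from this commit presumably get wired together further down in this script; the fit call is outside the hunks shown, so the variable names and Trainer arguments here are assumptions (and the detection collate_fn noted earlier would still be needed):

# illustrative wiring only, not a verbatim copy of train.py
datamodule = Spheres()
model = MRCNNModule(hidden_layer_size=-1, n_classes=2)
trainer = pl.Trainer(max_epochs=3, precision=16)
trainer.fit(model, datamodule=datamodule)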

@@ -29,7 +29,7 @@ SPHERES:
   value: 3

 EPOCHS:
-  value: 1
+  value: 3
 TRAIN_BATCH_SIZE:
   value: 128 # 100
 VAL_BATCH_SIZE: