From 4dab157ddaa26fa7b585b2b72e6c871a65addec0 Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Wed, 24 Aug 2022 14:56:41 +0200
Subject: [PATCH] feat: WIP, replacing U-Net with Mask R-CNN

Former-commit-id: f51a572adac901ff588e3a467f39ecd26376e617
[formerly 376595d7e5f906928379e25c1246e304b96b156d]
Former-commit-id: 3f4772ba3483702be6e5f7a29f06be93eb1f3bb2
---
 src/data/dataloader.py |  26 ++++----
 src/data/dataset.py    |  65 +++++++++++++++++++
 src/mrcnn/__init__.py  |   1 +
 src/mrcnn/module.py    | 144 +++++++++++++++++++++++++++++++++++++++++
 src/train.py           |  17 +++--
 wandb.yaml             |   2 +-
 6 files changed, 236 insertions(+), 19 deletions(-)
 create mode 100644 src/mrcnn/__init__.py
 create mode 100644 src/mrcnn/module.py

diff --git a/src/data/dataloader.py b/src/data/dataloader.py
index 1ff1d36..80d4b2d 100644
--- a/src/data/dataloader.py
+++ b/src/data/dataloader.py
@@ -1,8 +1,8 @@
 import albumentations as A
 import pytorch_lightning as pl
+import wandb
 from torch.utils.data import DataLoader, Subset
 
-import wandb
 from utils import RandomPaste
 
 from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
@@ -26,7 +26,7 @@ class Spheres(pl.LightningDataModule):
 
         # dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
 
-        dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TRAIN_prerender/")
+        dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
         dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows better utilization of the GPU
 
         return DataLoader(
@@ -38,15 +38,15 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )
 
-    def val_dataloader(self):
-        dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows better utilization of the GPU
-
-        return DataLoader(
-            dataset,
-            shuffle=False,
-            prefetch_factor=wandb.config.PREFETCH_FACTOR,
-            batch_size=wandb.config.VAL_BATCH_SIZE,
-            num_workers=wandb.config.WORKERS,
-            pin_memory=wandb.config.PIN_MEMORY,
-        )
+    # def val_dataloader(self):
+    #     dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
+    #     dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows better utilization of the GPU
+
+    #     return DataLoader(
+    #         dataset,
+    #         shuffle=False,
+    #         prefetch_factor=wandb.config.PREFETCH_FACTOR,
+    #         batch_size=wandb.config.VAL_BATCH_SIZE,
+    #         num_workers=wandb.config.WORKERS,
+    #         pin_memory=wandb.config.PIN_MEMORY,
+    #     )
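
Note: train_dataloader above still relies on PyTorch's default collate_fn, which stacks
samples into a single tensor and therefore cannot batch the variable-size target dicts
that torchvision's Mask R-CNN consumes. A minimal sketch of the usual workaround for
detection models (detection_collate is a hypothetical name, not part of this patch):

    def detection_collate(batch):
        # keep images and targets as parallel tuples instead of
        # stacking them into one tensor per batch
        return tuple(zip(*batch))

    # usage: DataLoader(dataset, ..., collate_fn=detection_collate)
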
diff --git a/src/data/dataset.py b/src/data/dataset.py
index 455cc5d..f062ec4 100644
--- a/src/data/dataset.py
+++ b/src/data/dataset.py
@@ -1,7 +1,9 @@
+import os
 from pathlib import Path
 
 import albumentations as A
 import numpy as np
+import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image
 from torch.utils.data import Dataset
@@ -111,3 +113,66 @@ class LabeledDataset2(Dataset):
         mask = mask.float()
 
         return image, mask
+
+
+class LabeledDataset3(Dataset):
+    def __init__(self, root, transforms):
+        self.root = root
+        self.transforms = transforms
+        # load all image files, sorting them to ensure that they are aligned
+        self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
+        self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
+
+    def __getitem__(self, idx):
+        # create paths from ids
+        img_path = os.path.join(self.root, "images", self.imgs[idx])
+        mask_path = os.path.join(self.root, "masks", self.masks[idx])
+
+        # load image and mask
+        img = Image.open(img_path).convert("RGB")
+        mask = Image.open(mask_path)
+
+        # convert mask to a numpy array to apply numpy operations
+        mask = np.array(mask)
+
+        obj_ids = np.unique(mask)
+        obj_ids = obj_ids[1:]  # the first id is the background, so remove it
+
+        # split the color-encoded mask into a set of binary masks
+        masks = mask == obj_ids[:, None, None]
+
+        # get bounding box coordinates for each mask
+        num_objs = len(obj_ids)
+        bboxes = []
+        for i in range(num_objs):
+            pos = np.where(masks[i])
+            xmin = np.min(pos[1])
+            xmax = np.max(pos[1])
+            ymin = np.min(pos[0])
+            ymax = np.max(pos[0])
+            bboxes.append([xmin, ymin, xmax, ymax])
+
+        # convert arrays to tensors
+        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
+        labels = torch.ones((num_objs,), dtype=torch.int64)  # there is only one class
+        masks = torch.as_tensor(masks, dtype=torch.uint8)
+
+        # image_id = torch.tensor([idx])
+        # area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        # iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # suppose all instances are not crowd
+
+        target = {}
+        target["boxes"] = bboxes
+        target["labels"] = labels
+        target["masks"] = masks
+        # target["image_id"] = image_id
+        # target["area"] = area
+        # target["iscrowd"] = iscrowd
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
diff --git a/src/mrcnn/__init__.py b/src/mrcnn/__init__.py
new file mode 100644
index 0000000..8629291
--- /dev/null
+++ b/src/mrcnn/__init__.py
@@ -0,0 +1 @@
+from .module import MRCNNModule
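
Note: LabeledDataset3 calls self.transforms(img, target), so it expects a detection-style
transform that takes and returns the (image, target) pair, as in the torchvision detection
reference scripts, not an image-only albumentations pipeline. A minimal sketch of such a
transform (ToTensorPair is a hypothetical name, not part of this patch):

    import torchvision.transforms.functional as F

    class ToTensorPair:
        # convert the PIL image to a float tensor in [0, 1]
        # and pass the target dict through unchanged
        def __call__(self, img, target):
            return F.to_tensor(img), target

    # usage: dataset = LabeledDataset3(root="...", transforms=ToTensorPair())
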
diff --git a/src/mrcnn/module.py b/src/mrcnn/module.py
new file mode 100644
index 0000000..ddf8826
--- /dev/null
+++ b/src/mrcnn/module.py
@@ -0,0 +1,144 @@
+"""PyTorch Lightning wrapper for the Mask R-CNN model."""
+
+import pytorch_lightning as pl
+import torch
+import torchvision
+import wandb
+from torchvision.models.detection._utils import Matcher
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.ops.boxes import box_iou
+
+
+def get_model_instance_segmentation(num_classes):
+    # load an instance segmentation model pre-trained on COCO
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
+
+    # get the number of input features for the classifier
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+    # replace the pre-trained box head with a new one
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+
+    # now get the number of input features for the mask classifier
+    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
+    hidden_layer = 256
+    # and replace the mask predictor with a new one
+    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
+
+    return model
+
+
+class MRCNNModule(pl.LightningModule):
+    def __init__(self, hidden_layer_size, n_classes):
+        super().__init__()
+
+        # Hyperparameters (hidden_layer_size is not wired into the model yet)
+        self.hidden_layer_size = hidden_layer_size
+        self.n_classes = n_classes
+
+        # log hyperparameters
+        self.save_hyperparameters()
+
+        # Network
+        self.model = get_model_instance_segmentation(n_classes)
+
+    def forward(self, imgs):
+        # Torchvision's Mask R-CNN returns the losses during training
+        # and the predictions during eval
+        self.model.eval()
+        return self.model(imgs)
+
+    def training_step(self, batch, batch_idx):
+        # unpack batch
+        images, targets = batch
+
+        # enable train mode
+        self.model.train()
+
+        # the detector takes both images and targets during training
+        loss_dict = self.model(images, targets)
+        loss = sum(loss_dict.values())
+
+        # log the individual losses; the "log" key of a returned dict
+        # is not interpreted by Lightning 1.x
+        self.log_dict(loss_dict)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        # unpack batch
+        images, targets = batch
+
+        # enable eval mode
+        self.model.eval()
+
+        # make a prediction
+        preds = self.model(images)
+
+        # compute the validation accuracy monitored by the lr scheduler
+        self.val_accuracy = torch.mean(
+            torch.stack(
+                [
+                    self.accuracy(
+                        target["boxes"],
+                        pred["boxes"],
+                        iou_threshold=0.5,
+                    )
+                    for target, pred in zip(targets, preds)
+                ],
+            )
+        )
+        self.log("val_accuracy", self.val_accuracy)
+
+        return self.val_accuracy
+
+    def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
+        """Box-matching accuracy: not the metric used by the evaluator, but very similar."""
+        total_gt = len(src_boxes)
+        total_pred = len(pred_boxes)
+        if total_gt > 0 and total_pred > 0:
+
+            # define the matcher and distance matrix based on iou
+            matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
+            match_quality_matrix = box_iou(src_boxes, pred_boxes)
+
+            results = matcher(match_quality_matrix)
+
+            true_positive = torch.count_nonzero(results.unique() != -1)
+            matched_elements = results[results > -1]
+
+            # unmatched predictions, plus duplicate matches to the same
+            # ground truth box, count as false positives
+            false_positive = torch.count_nonzero(results == -1) + (
+                len(matched_elements) - len(matched_elements.unique())
+            )
+            false_negative = total_gt - true_positive
+
+            return true_positive / (true_positive + false_positive + false_negative)
+
+        elif total_gt == 0:
+            if total_pred > 0:
+                return torch.tensor(0.0, device=self.device)
+            else:
+                return torch.tensor(1.0, device=self.device)
+        else:  # total_gt > 0 and total_pred == 0
+            return torch.tensor(0.0, device=self.device)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(
+            self.parameters(),
+            lr=wandb.config.LEARNING_RATE,
+            momentum=wandb.config.MOMENTUM,
+            weight_decay=wandb.config.WEIGHT_DECAY,
+            nesterov=wandb.config.NESTEROV,
+        )
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer,
+            T_0=3,
+            T_mult=1,
+            eta_min=wandb.config.LEARNING_RATE_MIN,
+            verbose=True,
+        )
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "monitor": "val_accuracy",
+            },
+        }
diff --git a/src/train.py b/src/train.py
index 63b14e5..86d85e6 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,11 +1,12 @@
 import logging
 
 import pytorch_lightning as pl
+import wandb
 from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
-import wandb
 from data import Spheres
+from mrcnn import MRCNNModule
 from unet import UNetModule
 from utils import ArtifactLog, TableLog
 
@@ -26,10 +27,15 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)
 
     # Create network
-    model = UNetModule(
-        n_channels=wandb.config.N_CHANNELS,
-        n_classes=wandb.config.N_CLASSES,
-        features=wandb.config.FEATURES,
+    # model = UNetModule(
+    #     n_channels=wandb.config.N_CHANNELS,
+    #     n_classes=wandb.config.N_CLASSES,
+    #     features=wandb.config.FEATURES,
+    # )
+
+    model = MRCNNModule(
+        hidden_layer_size=-1,
+        n_classes=2,
     )
 
     # load checkpoint
@@ -48,6 +54,7 @@
         max_epochs=wandb.config.EPOCHS,
         accelerator=wandb.config.DEVICE,
         benchmark=wandb.config.BENCHMARK,
+        deterministic=True,
         precision=16,
         logger=logger,
         log_every_n_steps=1,
diff --git a/wandb.yaml b/wandb.yaml
index 4746243..c26bf50 100644
--- a/wandb.yaml
+++ b/wandb.yaml
@@ -29,7 +29,7 @@ SPHERES:
   value: 3
 
 EPOCHS:
-  value: 1
+  value: 3
 TRAIN_BATCH_SIZE:
   value: 128 # 100
 VAL_BATCH_SIZE:
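
Note: for intuition, a small worked example (not part of the patch) of what the
Matcher-based accuracy in MRCNNModule computes, on toy boxes:

    import torch
    from torchvision.models.detection._utils import Matcher
    from torchvision.ops.boxes import box_iou

    gt = torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]])
    pred = torch.tensor([[1.0, 1.0, 10.0, 10.0], [50.0, 50.0, 60.0, 60.0]])

    matcher = Matcher(0.5, 0.5, allow_low_quality_matches=False)
    results = matcher(box_iou(gt, pred))  # one matched gt index (or -1) per prediction
    # results == tensor([0, -1]): pred 0 overlaps gt 0 (IoU 0.81), pred 1 matches nothing

    tp = torch.count_nonzero(results.unique() != -1)  # 1 true positive
    fp = torch.count_nonzero(results == -1)           # 1 unmatched prediction
    fn = len(gt) - tp                                 # 1 missed ground truth
    print(tp / (tp + fp + fn))                        # tensor(0.3333)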