feat: WIP, replacing U-Net by Mask R-CNN
Former-commit-id: f51a572adac901ff588e3a467f39ecd26376e617 [formerly 376595d7e5f906928379e25c1246e304b96b156d] Former-commit-id: 3f4772ba3483702be6e5f7a29f06be93eb1f3bb2
This commit is contained in:
parent
c50235bb1e
commit
4dab157dda
|
@ -1,8 +1,8 @@
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
import wandb
|
||||||
from torch.utils.data import DataLoader, Subset
|
from torch.utils.data import DataLoader, Subset
|
||||||
|
|
||||||
import wandb
|
|
||||||
from utils import RandomPaste
|
from utils import RandomPaste
|
||||||
|
|
||||||
from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
|
from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
|
||||||
|
@ -26,7 +26,7 @@ class Spheres(pl.LightningDataModule):
|
||||||
|
|
||||||
# dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
# dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
||||||
|
|
||||||
dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TRAIN_prerender/")
|
dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
|
||||||
dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
|
dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
|
@ -38,15 +38,15 @@ class Spheres(pl.LightningDataModule):
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
)
|
)
|
||||||
|
|
||||||
def val_dataloader(self):
|
# def val_dataloader(self):
|
||||||
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
# dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||||
dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
|
# dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
|
||||||
|
|
||||||
return DataLoader(
|
# return DataLoader(
|
||||||
dataset,
|
# dataset,
|
||||||
shuffle=False,
|
# shuffle=False,
|
||||||
prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
# prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
||||||
batch_size=wandb.config.VAL_BATCH_SIZE,
|
# batch_size=wandb.config.VAL_BATCH_SIZE,
|
||||||
num_workers=wandb.config.WORKERS,
|
# num_workers=wandb.config.WORKERS,
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
# pin_memory=wandb.config.PIN_MEMORY,
|
||||||
)
|
# )
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
@ -111,3 +113,66 @@ class LabeledDataset2(Dataset):
|
||||||
mask = mask.float()
|
mask = mask.float()
|
||||||
|
|
||||||
return image, mask
|
return image, mask
|
||||||
|
|
||||||
|
|
||||||
|
class LabeledDataset3(object):
|
||||||
|
def __init__(self, root, transforms):
|
||||||
|
self.root = root
|
||||||
|
self.transforms = transforms
|
||||||
|
# load all image files, sorting them to ensure that they are aligned
|
||||||
|
self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
|
||||||
|
self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# create paths from ids
|
||||||
|
img_path = os.path.join(self.root, "images", self.imgs[idx])
|
||||||
|
mask_path = os.path.join(self.root, "masks", self.masks[idx])
|
||||||
|
|
||||||
|
# load image and mask
|
||||||
|
img = Image.open(img_path).convert("RGB")
|
||||||
|
mask = Image.open(mask_path)
|
||||||
|
|
||||||
|
# convert mask to numpy array to apply operations
|
||||||
|
mask = np.array(mask)
|
||||||
|
|
||||||
|
obj_ids = np.unique(mask)
|
||||||
|
obj_ids = obj_ids[1:] # first id is the background, so remove it
|
||||||
|
|
||||||
|
# split the color-encoded mask into a set of binary masks
|
||||||
|
masks = mask == obj_ids[:, None, None]
|
||||||
|
|
||||||
|
# get bounding box coordinates for each mask
|
||||||
|
num_objs = len(obj_ids)
|
||||||
|
bboxes = []
|
||||||
|
for i in range(num_objs):
|
||||||
|
pos = np.where(masks[i])
|
||||||
|
xmin = np.min(pos[1])
|
||||||
|
xmax = np.max(pos[1])
|
||||||
|
ymin = np.min(pos[0])
|
||||||
|
ymax = np.max(pos[0])
|
||||||
|
bboxes.append([xmin, ymin, xmax, ymax])
|
||||||
|
|
||||||
|
# convert arrays to tensors
|
||||||
|
bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
|
||||||
|
labels = torch.ones((num_objs,), dtype=torch.int64) # there is only one class
|
||||||
|
masks = torch.as_tensor(masks, dtype=torch.uint8)
|
||||||
|
|
||||||
|
# image_id = torch.tensor([idx])
|
||||||
|
# area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
|
||||||
|
# iscrowd = torch.zeros((num_objs,), dtype=torch.int64) # suppose all instances are not crowd
|
||||||
|
|
||||||
|
target = {}
|
||||||
|
target["boxes"] = bboxes
|
||||||
|
target["labels"] = labels
|
||||||
|
target["masks"] = masks
|
||||||
|
# target["image_id"] = image_id
|
||||||
|
# target["area"] = area
|
||||||
|
# target["iscrowd"] = iscrowd
|
||||||
|
|
||||||
|
if self.transforms is not None:
|
||||||
|
img, target = self.transforms(img, target)
|
||||||
|
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.imgs)
|
||||||
|
|
1
src/mrcnn/__init__.py
Normal file
1
src/mrcnn/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .module import MRCNNModule
|
144
src/mrcnn/module.py
Normal file
144
src/mrcnn/module.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
"""Pytorch lightning wrapper for model."""
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import wandb
|
||||||
|
from torchvision.models.detection._utils import Matcher
|
||||||
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
|
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||||
|
from torchvision.ops.boxes import box_iou
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_instance_segmentation(num_classes):
|
||||||
|
# load an instance segmentation model pre-trained on COCO
|
||||||
|
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
|
||||||
|
|
||||||
|
# get number of input features for the classifier
|
||||||
|
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||||
|
# replace the pre-trained head with a new one
|
||||||
|
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
||||||
|
|
||||||
|
# now get the number of input features for the mask classifier
|
||||||
|
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
||||||
|
hidden_layer = 256
|
||||||
|
# and replace the mask predictor with a new one
|
||||||
|
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class MRCNNModule(pl.LightningModule):
|
||||||
|
def __init__(self, hidden_layer_size, n_classes):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
self.hidden_layers_size = hidden_layer_size
|
||||||
|
self.n_classes = n_classes
|
||||||
|
|
||||||
|
# log hyperparameters
|
||||||
|
self.save_hyperparameters()
|
||||||
|
|
||||||
|
# Network
|
||||||
|
self.model = get_model_instance_segmentation(n_classes)
|
||||||
|
|
||||||
|
def forward(self, imgs):
|
||||||
|
# Torchvision FasterRCNN returns the loss during training
|
||||||
|
# and the boxes during eval
|
||||||
|
self.model.eval()
|
||||||
|
return self.model(imgs)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
# unpack batch
|
||||||
|
images, targets = batch
|
||||||
|
|
||||||
|
# enable train mode
|
||||||
|
self.model.train()
|
||||||
|
|
||||||
|
# fasterrcnn takes both images and targets for training
|
||||||
|
loss_dict = self.model(images, targets)
|
||||||
|
loss = sum(loss_dict.values())
|
||||||
|
return {"loss": loss, "log": loss_dict}
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
# unpack batch
|
||||||
|
images, targets = batch
|
||||||
|
|
||||||
|
# enable eval mode
|
||||||
|
self.detector.eval()
|
||||||
|
|
||||||
|
# make a prediction
|
||||||
|
preds = self.detector(images)
|
||||||
|
|
||||||
|
# compute validation loss
|
||||||
|
self.val_loss = torch.mean(
|
||||||
|
torch.stack(
|
||||||
|
[
|
||||||
|
self.accuracy(
|
||||||
|
target,
|
||||||
|
pred["boxes"],
|
||||||
|
iou_threshold=0.5,
|
||||||
|
)
|
||||||
|
for target, pred in zip(targets, preds)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.val_loss
|
||||||
|
|
||||||
|
def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
|
||||||
|
"""
|
||||||
|
The accuracy method is not the one used in the evaluator but very similar
|
||||||
|
"""
|
||||||
|
total_gt = len(src_boxes)
|
||||||
|
total_pred = len(pred_boxes)
|
||||||
|
if total_gt > 0 and total_pred > 0:
|
||||||
|
|
||||||
|
# Define the matcher and distance matrix based on iou
|
||||||
|
matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
|
||||||
|
match_quality_matrix = box_iou(src_boxes, pred_boxes)
|
||||||
|
|
||||||
|
results = matcher(match_quality_matrix)
|
||||||
|
|
||||||
|
true_positive = torch.count_nonzero(results.unique() != -1)
|
||||||
|
matched_elements = results[results > -1]
|
||||||
|
|
||||||
|
# in Matcher, a pred element can be matched only twice
|
||||||
|
false_positive = torch.count_nonzero(results == -1) + (
|
||||||
|
len(matched_elements) - len(matched_elements.unique())
|
||||||
|
)
|
||||||
|
false_negative = total_gt - true_positive
|
||||||
|
|
||||||
|
return true_positive / (true_positive + false_positive + false_negative)
|
||||||
|
|
||||||
|
elif total_gt == 0:
|
||||||
|
if total_pred > 0:
|
||||||
|
return torch.tensor(0.0).cuda()
|
||||||
|
else:
|
||||||
|
return torch.tensor(1.0).cuda()
|
||||||
|
elif total_gt > 0 and total_pred == 0:
|
||||||
|
return torch.tensor(0.0).cuda()
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
self.parameters(),
|
||||||
|
lr=wandb.config.LEARNING_RATE,
|
||||||
|
momentum=wandb.config.MOMENTUM,
|
||||||
|
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||||
|
nesterov=wandb.config.NESTEROV,
|
||||||
|
)
|
||||||
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
optimizer,
|
||||||
|
T_0=3,
|
||||||
|
T_mult=1,
|
||||||
|
lr=wandb.config.LEARNING_RATE_MIN,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"optimizer": optimizer,
|
||||||
|
"lr_scheduler": {
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"monitor": "val_accuracy",
|
||||||
|
},
|
||||||
|
}
|
17
src/train.py
17
src/train.py
|
@ -1,11 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
import wandb
|
||||||
from pytorch_lightning.callbacks import RichProgressBar
|
from pytorch_lightning.callbacks import RichProgressBar
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
import wandb
|
|
||||||
from data import Spheres
|
from data import Spheres
|
||||||
|
from mrcnn import MRCNNModule
|
||||||
from unet import UNetModule
|
from unet import UNetModule
|
||||||
from utils import ArtifactLog, TableLog
|
from utils import ArtifactLog, TableLog
|
||||||
|
|
||||||
|
@ -26,10 +27,15 @@ if __name__ == "__main__":
|
||||||
pl.seed_everything(69420, workers=True)
|
pl.seed_everything(69420, workers=True)
|
||||||
|
|
||||||
# Create network
|
# Create network
|
||||||
model = UNetModule(
|
# model = UNetModule(
|
||||||
n_channels=wandb.config.N_CHANNELS,
|
# n_channels=wandb.config.N_CHANNELS,
|
||||||
n_classes=wandb.config.N_CLASSES,
|
# n_classes=wandb.config.N_CLASSES,
|
||||||
features=wandb.config.FEATURES,
|
# features=wandb.config.FEATURES,
|
||||||
|
# )
|
||||||
|
|
||||||
|
model = MRCNNModule(
|
||||||
|
hidden_layer_size=-1,
|
||||||
|
n_classes=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load checkpoint
|
# load checkpoint
|
||||||
|
@ -48,6 +54,7 @@ if __name__ == "__main__":
|
||||||
max_epochs=wandb.config.EPOCHS,
|
max_epochs=wandb.config.EPOCHS,
|
||||||
accelerator=wandb.config.DEVICE,
|
accelerator=wandb.config.DEVICE,
|
||||||
benchmark=wandb.config.BENCHMARK,
|
benchmark=wandb.config.BENCHMARK,
|
||||||
|
deterministic=True,
|
||||||
precision=16,
|
precision=16,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
log_every_n_steps=1,
|
log_every_n_steps=1,
|
||||||
|
|
|
@ -29,7 +29,7 @@ SPHERES:
|
||||||
value: 3
|
value: 3
|
||||||
|
|
||||||
EPOCHS:
|
EPOCHS:
|
||||||
value: 1
|
value: 3
|
||||||
TRAIN_BATCH_SIZE:
|
TRAIN_BATCH_SIZE:
|
||||||
value: 128 # 100
|
value: 128 # 100
|
||||||
VAL_BATCH_SIZE:
|
VAL_BATCH_SIZE:
|
||||||
|
|
Loading…
Reference in a new issue