feat: WIP, replacing U-Net by Mask R-CNN
Former-commit-id: f51a572adac901ff588e3a467f39ecd26376e617 [formerly 376595d7e5f906928379e25c1246e304b96b156d] Former-commit-id: 3f4772ba3483702be6e5f7a29f06be93eb1f3bb2
This commit is contained in:
parent
c50235bb1e
commit
4dab157dda
|
@ -1,8 +1,8 @@
|
|||
import albumentations as A
|
||||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
import wandb
|
||||
from utils import RandomPaste
|
||||
|
||||
from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
|
||||
|
@ -26,7 +26,7 @@ class Spheres(pl.LightningDataModule):
|
|||
|
||||
# dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
||||
|
||||
dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TRAIN_prerender/")
|
||||
dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
|
||||
dataset = Subset(dataset, list(range(len(dataset)))) # somehow this makes better use of the gpu
|
||||
|
||||
return DataLoader(
|
||||
|
@ -38,15 +38,15 @@ class Spheres(pl.LightningDataModule):
|
|||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||
dataset = Subset(dataset, list(range(len(dataset)))) # somehow this makes better use of the gpu
|
||||
# def val_dataloader(self):
|
||||
# dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||
# dataset = Subset(dataset, list(range(len(dataset)))) # somehow this makes better use of the gpu
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
||||
batch_size=wandb.config.VAL_BATCH_SIZE,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
)
|
||||
# return DataLoader(
|
||||
# dataset,
|
||||
# shuffle=False,
|
||||
# prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
||||
# batch_size=wandb.config.VAL_BATCH_SIZE,
|
||||
# num_workers=wandb.config.WORKERS,
|
||||
# pin_memory=wandb.config.PIN_MEMORY,
|
||||
# )
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
import torch
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -111,3 +113,66 @@ class LabeledDataset2(Dataset):
|
|||
mask = mask.float()
|
||||
|
||||
return image, mask
|
||||
|
||||
|
||||
class LabeledDataset3(Dataset):
    """Instance-segmentation dataset read from `<root>/images` and `<root>/masks`.

    Each sample is a pair `(img, target)` where `target` is a dict with the
    keys `"boxes"`, `"labels"` and `"masks"`, i.e. the target layout used by
    torchvision detection models.

    Args:
        root: directory containing an `images` and a `masks` sub-directory.
        transforms: optional callable applied as `transforms(img, target)` and
            expected to return the transformed `(img, target)` pair; `None`
            disables it.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        # Both listings are sorted so that the i-th image lines up with the
        # i-th mask (this relies on matching file names in the two folders).
        self.imgs = sorted(os.listdir(os.path.join(root, "images")))
        self.masks = sorted(os.listdir(os.path.join(root, "masks")))

    def __getitem__(self, idx):
        """Load sample `idx` and build its detection target.

        Returns:
            `(img, target)`: the RGB image and a dict holding
            `boxes` (float32, one `[xmin, ymin, xmax, ymax]` row per object),
            `labels` (int64, all ones) and `masks` (uint8, one binary mask per
            object).
        """
        # create paths from ids
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        mask_path = os.path.join(self.root, "masks", self.masks[idx])

        # load image and mask; the mask is turned into an array right away
        img = Image.open(img_path).convert("RGB")
        mask = np.array(Image.open(mask_path))

        # Every distinct value in the mask is one instance id.
        # NOTE(review): dropping the first (smallest) id assumes the background
        # is always present and has the lowest value - confirm for these masks.
        obj_ids = np.unique(mask)[1:]

        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]
        num_objs = len(obj_ids)

        # tight bounding box around the non-zero pixels of each binary mask
        bboxes = []
        for obj_mask in masks:
            rows, cols = np.where(obj_mask)
            bboxes.append([int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())])

        # convert arrays to tensors; the reshape keeps a (0, 4) shape when the
        # image contains no object (a bare empty list would become shape (0,))
        bboxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.ones((num_objs,), dtype=torch.int64)  # there is only one class
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        target = {
            "boxes": bboxes,
            "labels": labels,
            "masks": masks,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images found in `<root>/images`."""
        return len(self.imgs)
|
||||
|
|
1
src/mrcnn/__init__.py
Normal file
1
src/mrcnn/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .module import MRCNNModule
|
144
src/mrcnn/module.py
Normal file
144
src/mrcnn/module.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
"""Pytorch lightning wrapper for model."""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
import wandb
|
||||
from torchvision.models.detection._utils import Matcher
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||
from torchvision.ops.boxes import box_iou
|
||||
|
||||
|
||||
def get_model_instance_segmentation(num_classes, hidden_layer=256):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with fresh prediction heads.

    The model is created with `weights="DEFAULT"` (pre-trained weights), then
    its box predictor and mask predictor are replaced by new, untrained heads
    sized for `num_classes`.

    Args:
        num_classes: number of output classes of both new heads.
        hidden_layer: number of channels of the hidden layer in the new mask
            predictor. Defaults to 256, the value that was previously
            hard-coded.

    Returns:
        The modified Mask R-CNN model.
    """
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
    roi_heads = model.roi_heads

    # the new box head must accept the same feature size as the one it replaces
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # same for the mask head: reuse the input channel count of the old one
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    roi_heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, hidden_layer, num_classes)

    return model
|
||||
|
||||
|
||||
class MRCNNModule(pl.LightningModule):
    """Pytorch Lightning wrapper around a torchvision Mask R-CNN.

    Args:
        hidden_layer_size: stored as `self.hidden_layers_size` and saved with
            the hyperparameters; it is not used to build the network here.
        n_classes: number of classes passed to
            `get_model_instance_segmentation`.
    """

    def __init__(self, hidden_layer_size, n_classes):
        super().__init__()

        # Hyperparameters
        self.hidden_layers_size = hidden_layer_size
        self.n_classes = n_classes

        # log hyperparameters
        self.save_hyperparameters()

        # Network
        self.model = get_model_instance_segmentation(n_classes)

    def forward(self, imgs):
        """Run inference on `imgs` and return the model's predictions.

        The wrapped model is switched to eval mode first, because torchvision
        detection models return losses in train mode and predictions in eval
        mode.
        """
        self.model.eval()
        return self.model(imgs)

    def training_step(self, batch, batch_idx):
        """Compute the summed detection losses for one training batch.

        Args:
            batch: `(images, targets)` pair as expected by the wrapped model.
            batch_idx: index of the batch (unused).

        Returns:
            A dict whose `"loss"` entry is the sum of all losses returned by
            the model.
        """
        # unpack batch
        images, targets = batch

        # enable train mode (forward() may have left the model in eval mode)
        self.model.train()

        # in train mode the model takes both images and targets and returns
        # a dict of individual losses
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())

        # The legacy `"log"` key of the returned dict is no longer read by
        # Lightning, so the individual losses are logged explicitly instead.
        self.log_dict(loss_dict, batch_size=len(images))

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute the mean box accuracy (IoU threshold 0.5) over one batch.

        The value is stored in `self.val_loss`, logged as `"val_accuracy"`
        (the key named as `monitor` in `configure_optimizers`) and returned.
        """
        # unpack batch
        images, targets = batch

        # enable eval mode so that the model returns predictions, not losses
        self.model.eval()

        # make a prediction
        preds = self.model(images)

        # one accuracy value per image, computed on the boxes only
        per_image = [
            self.accuracy(target["boxes"], pred["boxes"], iou_threshold=0.5)
            for target, pred in zip(targets, preds)
        ]
        self.val_loss = torch.mean(torch.stack(per_image))

        self.log("val_accuracy", self.val_loss, batch_size=len(images))

        return self.val_loss

    def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
        """Return TP / (TP + FP + FN) for one image, matching boxes by IoU.

        The accuracy method is not the one used in the evaluator but very
        similar.

        Args:
            src_boxes: ground-truth boxes.
            pred_boxes: predicted boxes.
            iou_threshold: IoU used as both the low and the high threshold of
                the matcher.

        Returns:
            A scalar tensor. It is 1.0 when there are neither ground-truth nor
            predicted boxes, and 0.0 when exactly one of the two is empty.
        """
        total_gt = len(src_boxes)
        total_pred = len(pred_boxes)

        # Degenerate cases first: nothing can be matched when either side is
        # empty. The result is created on the module's device instead of
        # forcing CUDA, so that CPU runs work as well.
        if total_gt == 0 or total_pred == 0:
            score = 1.0 if total_gt == 0 and total_pred == 0 else 0.0
            return torch.tensor(score, device=self.device)

        # Define the matcher and distance matrix based on iou
        matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
        match_quality_matrix = box_iou(src_boxes, pred_boxes)

        # per torchvision's Matcher: one entry per predicted box, holding the
        # index of its matched ground-truth box, or a negative value if none
        results = matcher(match_quality_matrix)

        # a ground-truth box counts once, however many predictions hit it
        true_positive = torch.count_nonzero(results.unique() != -1)
        matched_elements = results[results > -1]

        # false positives are the unmatched predictions plus the surplus
        # predictions matched to an already-claimed ground-truth box
        false_positive = torch.count_nonzero(results == -1) + (
            len(matched_elements) - len(matched_elements.unique())
        )
        false_negative = total_gt - true_positive

        return true_positive / (true_positive + false_positive + false_negative)

    def configure_optimizers(self):
        """Create the SGD optimizer and its cosine warm-restart scheduler.

        All optimizer settings and the scheduler's minimum learning rate are
        read from `wandb.config`.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=wandb.config.LEARNING_RATE,
            momentum=wandb.config.MOMENTUM,
            weight_decay=wandb.config.WEIGHT_DECAY,
            nesterov=wandb.config.NESTEROV,
        )
        # the minimum learning rate of this scheduler is named `eta_min`;
        # it has no `lr` keyword
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=3,
            T_mult=1,
            eta_min=wandb.config.LEARNING_RATE_MIN,
            verbose=True,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_accuracy",
            },
        }
|
17
src/train.py
17
src/train.py
|
@ -1,11 +1,12 @@
|
|||
import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
from pytorch_lightning.callbacks import RichProgressBar
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
import wandb
|
||||
from data import Spheres
|
||||
from mrcnn import MRCNNModule
|
||||
from unet import UNetModule
|
||||
from utils import ArtifactLog, TableLog
|
||||
|
||||
|
@ -26,10 +27,15 @@ if __name__ == "__main__":
|
|||
pl.seed_everything(69420, workers=True)
|
||||
|
||||
# Create network
|
||||
model = UNetModule(
|
||||
n_channels=wandb.config.N_CHANNELS,
|
||||
n_classes=wandb.config.N_CLASSES,
|
||||
features=wandb.config.FEATURES,
|
||||
# model = UNetModule(
|
||||
# n_channels=wandb.config.N_CHANNELS,
|
||||
# n_classes=wandb.config.N_CLASSES,
|
||||
# features=wandb.config.FEATURES,
|
||||
# )
|
||||
|
||||
model = MRCNNModule(
|
||||
hidden_layer_size=-1,
|
||||
n_classes=2,
|
||||
)
|
||||
|
||||
# load checkpoint
|
||||
|
@ -48,6 +54,7 @@ if __name__ == "__main__":
|
|||
max_epochs=wandb.config.EPOCHS,
|
||||
accelerator=wandb.config.DEVICE,
|
||||
benchmark=wandb.config.BENCHMARK,
|
||||
deterministic=True,
|
||||
precision=16,
|
||||
logger=logger,
|
||||
log_every_n_steps=1,
|
||||
|
|
|
@ -29,7 +29,7 @@ SPHERES:
|
|||
value: 3
|
||||
|
||||
EPOCHS:
|
||||
value: 1
|
||||
value: 3
|
||||
TRAIN_BATCH_SIZE:
|
||||
value: 128 # 100
|
||||
VAL_BATCH_SIZE:
|
||||
|
|
Loading…
Reference in a new issue