diff --git a/src/data/dataloader.py b/src/data/dataloader.py
index 80d4b2d..c08f9c4 100644
--- a/src/data/dataloader.py
+++ b/src/data/dataloader.py
@@ -1,11 +1,15 @@
 import albumentations as A
 import pytorch_lightning as pl
-import wandb
+from albumentations.pytorch import ToTensorV2
 from torch.utils.data import DataLoader, Subset
-from utils import RandomPaste
+import wandb
 
-from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
+from .dataset import RealDataset
+
+
+def collate_fn(batch):
+    return tuple(zip(*batch))
 
 
 class Spheres(pl.LightningDataModule):
@@ -13,21 +17,22 @@ class Spheres(pl.LightningDataModule):
         super().__init__()
 
     def train_dataloader(self):
-        # transform = A.Compose(
-        #     [
-        #         A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-        #         A.Flip(),
-        #         A.ColorJitter(),
-        #         RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
-        #         A.GaussianBlur(),
-        #         A.ISONoise(),
-        #     ],
-        # )
+        transforms = A.Compose(
+            [
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+            bbox_params=A.BboxParams(
+                format="pascal_voc",
+                min_area=0.0,
+                min_visibility=0.0,
+                label_fields=["labels"],
+            ),
+        )
 
-        # dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
-
-        dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
-        dataset = Subset(dataset, list(range(len(dataset))))  # somhow this allows to better utilize the gpu
+        dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms)
+        print(f"len(dataset)={len(dataset)}")
+        dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows better GPU utilization
 
         return DataLoader(
             dataset,
@@ -36,11 +41,12 @@ class Spheres(pl.LightningDataModule):
             batch_size=wandb.config.TRAIN_BATCH_SIZE,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
+            collate_fn=collate_fn,
         )
 
     # def val_dataloader(self):
     #     dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-    #     dataset = Subset(dataset, list(range(len(dataset))))  # somhow this allows to better utilize the gpu
+    #     dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows better GPU utilization
 
     #     return DataLoader(
     #         dataset,
diff --git a/src/data/dataset.py b/src/data/dataset.py
index f062ec4..4b8fff3 100644
--- a/src/data/dataset.py
+++ b/src/data/dataset.py
@@ -1,147 +1,42 @@
 import os
 from pathlib import Path
 
-import albumentations as A
 import numpy as np
 import torch
-from albumentations.pytorch import ToTensorV2
 from PIL import Image
 from torch.utils.data import Dataset
 
 
-class SyntheticDataset(Dataset):
-    def __init__(self, image_dir, transform):
-        self.images = list(Path(image_dir).glob("**/*.jpg"))
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, index):
-        # open and convert image
-        image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
-
-        # create empty mask of same size
-        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
-
-        # augment image and mask
-        augmentations = self.transform(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # convert image & mask to Tensor float in [0, 1]
-        post_process = A.Compose(
-            [
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
-            ],
-        )
-        augmentations = post_process(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # make sure image and mask are floats
-        image = image.float()
-        mask = mask.float()
-
-        return image, mask
-
-
-class LabeledDataset(Dataset):
-    def __init__(self, image_dir):
-        self.images = list(Path(image_dir).glob("**/*.jpg"))
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, index):
-        # open and convert image
-        image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
-
-        # open and convert mask
-        mask_path = self.images[index].parent.joinpath("MASK.PNG")
-        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) // 255
-
-        # convert image & mask to Tensor float in [0, 1]
-        post_process = A.Compose(
-            [
-                A.SmallestMaxSize(1024),
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
-            ],
-        )
-        augmentations = post_process(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # make sure image and mask are floats, TODO: mettre dans le post_process, ToFloat Image only
-        image = image.float()
-        mask = mask.float()
-
-        return image, mask
-
-
-class LabeledDataset2(Dataset):
-    def __init__(self, image_dir):
-        self.image_dir = Path(image_dir)
-
-    def __len__(self):
-        return len(list(self.image_dir.iterdir()))
-
-    def __getitem__(self, index):
-        path = self.image_dir / str(index)
-
-        # open and convert image
-        image = np.array(Image.open(path / "image.jpg").convert("RGB"), dtype=np.uint8)
-
-        # open and convert mask
-        mask = np.array(Image.open(path / "MASK.PNG").convert("L"), dtype=np.uint8) // 255
-
-        # convert image & mask to Tensor float in [0, 1]
-        post_process = A.Compose(
-            [
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
-            ],
-        )
-        augmentations = post_process(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # make sure image and mask are floats, TODO: mettre dans le post_process, ToFloat Image only
-        image = image.float()
-        mask = mask.float()
-
-        return image, mask
-
-
-class LabeledDataset3(object):
-    def __init__(self, root, transforms):
+class RealDataset(Dataset):
+    def __init__(self, root, transforms=None):
         self.root = root
         self.transforms = transforms
+        # load all image files, sorting them to ensure that they are aligned
         self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
         self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
 
     def __getitem__(self, idx):
         # create paths from ids
-        img_path = os.path.join(self.root, "images", self.imgs[idx])
+        image_path = os.path.join(self.root, "images", self.imgs[idx])
         mask_path = os.path.join(self.root, "masks", self.masks[idx])
 
         # load image and mask
-        img = Image.open(img_path).convert("RGB")
+        image = Image.open(image_path).convert("RGB")
         mask = Image.open(mask_path)
 
-        # convert mask to numpy array to apply operations
+        # convert to numpy arrays
+        image = np.array(image)
         mask = np.array(mask)
 
+        # get ids from mask
         obj_ids = np.unique(mask)
         obj_ids = obj_ids[1:]  # first id is the background, so remove it
 
         # split the color-encoded mask into a set of binary masks
         masks = mask == obj_ids[:, None, None]
 
-        # get bounding box coordinates for each mask
+        # create bboxes from masks (pascal format)
         num_objs = len(obj_ids)
         bboxes = []
         for i in range(num_objs):
@@ -152,27 +47,46 @@ class LabeledDataset3(object):
             ymax = np.max(pos[0])
             bboxes.append([xmin, ymin, xmax, ymax])
 
-        # convert arrays to tensors
+        # convert arrays to tensors, TODO: check what albumentations wants, to simplify the lines below
         bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
-        labels = torch.ones((num_objs,), dtype=torch.int64)  # there is only one class
-        masks = torch.as_tensor(masks, dtype=torch.uint8)
-
-        # image_id = torch.tensor([idx])
-        # area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
-        # iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # suppose all instances are not crowd
-
-        target = {}
-        target["boxes"] = bboxes
-        target["labels"] = labels
-        target["masks"] = masks
-        # target["image_id"] = image_id
-        # target["area"] = area
-        # target["iscrowd"] = iscrowd
+        labels = torch.ones((num_objs,), dtype=torch.int64)  # suppose there is only one class (id=1)
+        masks = [mask for mask in masks]  # albumentations wants a list of masks
 
         if self.transforms is not None:
-            img, target = self.transforms(img, target)
+            # arrange transform data
+            data = {
+                "image": image,
+                "labels": labels,
+                "bboxes": bboxes,
+                "masks": masks,
+            }
+            # apply transform
+            augmented = self.transforms(**data)
+            # get augmented image and bboxes
+            image = augmented["image"]
+            bboxes = augmented["bboxes"]
+            labels = augmented["labels"]
+            # get masks
+            masks = augmented["masks"]
 
-        return img, target
+        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
+        labels = torch.as_tensor(labels, dtype=torch.int64)  # int64 required by torchvision maskrcnn
+        masks = torch.stack(masks)  # stack masks, as expected by torchvision maskrcnn
+
+        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        image_id = torch.tensor([idx])
+        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # suppose all instances are not crowd
+
+        target = {
+            "boxes": bboxes,
+            "labels": labels,
+            "masks": masks,
+            "area": area,
+            "image_id": image_id,
+            "iscrowd": iscrowd,
+        }
+
+        return image, target
 
     def __len__(self):
         return len(self.imgs)
diff --git a/src/mrcnn/module.py b/src/mrcnn/module.py
index ddf8826..b1e8ae5 100644
--- a/src/mrcnn/module.py
+++ b/src/mrcnn/module.py
@@ -3,16 +3,17 @@
 import pytorch_lightning as pl
 import torch
 import torchvision
-import wandb
 from torchvision.models.detection._utils import Matcher
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
 from torchvision.ops.boxes import box_iou
 
+import wandb
+
 
 def get_model_instance_segmentation(num_classes):
     # load an instance segmentation model pre-trained on COCO
-    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
 
     # get number of input features for the classifier
     in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -42,82 +43,84 @@ class MRCNNModule(pl.LightningModule):
         # Network
         self.model = get_model_instance_segmentation(n_classes)
 
-    def forward(self, imgs):
-        # Torchvision FasterRCNN returns the loss during training
-        # and the boxes during eval
-        self.model.eval()
-        return self.model(imgs)
+    # def forward(self, imgs):
+    #     # Torchvision FasterRCNN returns the loss during training
+    #     # and the boxes during eval
+    #     self.model.eval()
+    #     return self.model(imgs)
 
     def training_step(self, batch, batch_idx):
         # unpack batch
        images, targets = batch
 
         # enable train mode
-        self.model.train()
+        # self.model.train()
 
         # fasterrcnn takes both images and targets for training
         loss_dict = self.model(images, targets)
         loss = sum(loss_dict.values())
 
+        # self.log_dict(loss_dict)
+        # self.log(loss)
         return {"loss": loss, "log": loss_dict}
 
-    def validation_step(self, batch, batch_idx):
-        # unpack batch
-        images, targets = batch
+    # def validation_step(self, batch, batch_idx):
+    #     # unpack batch
+    #     images, targets = batch
 
-        # enable eval mode
-        self.detector.eval()
+    #     # enable eval mode
+    #     # self.detector.eval()
 
-        # make a prediction
-        preds = self.detector(images)
+    #     # make a prediction
+    #     preds = self.model(images)
 
-        # compute validation loss
-        self.val_loss = torch.mean(
-            torch.stack(
-                [
-                    self.accuracy(
-                        target,
-                        pred["boxes"],
-                        iou_threshold=0.5,
-                    )
-                    for target, pred in zip(targets, preds)
-                ],
-            )
-        )
+    #     # compute validation loss
+    #     self.val_loss = torch.mean(
+    #         torch.stack(
+    #             [
+    #                 self.accuracy(
+    #                     target,
+    #                     pred["boxes"],
+    #                     iou_threshold=0.5,
+    #                 )
+    #                 for target, pred in zip(targets, preds)
+    #             ],
+    #         )
+    #     )
 
-        return self.val_loss
+    #     return self.val_loss
 
-    def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
-        """
-        The accuracy method is not the one used in the evaluator but very similar
-        """
-        total_gt = len(src_boxes)
-        total_pred = len(pred_boxes)
-        if total_gt > 0 and total_pred > 0:
+    # def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
+    #     """
+    #     The accuracy method is not the one used in the evaluator but very similar
+    #     """
+    #     total_gt = len(src_boxes)
+    #     total_pred = len(pred_boxes)
+    #     if total_gt > 0 and total_pred > 0:
 
-            # Define the matcher and distance matrix based on iou
-            matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
-            match_quality_matrix = box_iou(src_boxes, pred_boxes)
+    #         # Define the matcher and distance matrix based on iou
+    #         matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
+    #         match_quality_matrix = box_iou(src_boxes, pred_boxes)
 
-            results = matcher(match_quality_matrix)
+    #         results = matcher(match_quality_matrix)
 
-            true_positive = torch.count_nonzero(results.unique() != -1)
-            matched_elements = results[results > -1]
+    #         true_positive = torch.count_nonzero(results.unique() != -1)
+    #         matched_elements = results[results > -1]
 
-            # in Matcher, a pred element can be matched only twice
-            false_positive = torch.count_nonzero(results == -1) + (
-                len(matched_elements) - len(matched_elements.unique())
-            )
-            false_negative = total_gt - true_positive
+    #         # in Matcher, a pred element can be matched only twice
+    #         false_positive = torch.count_nonzero(results == -1) + (
+    #             len(matched_elements) - len(matched_elements.unique())
+    #         )
+    #         false_negative = total_gt - true_positive
 
-            return true_positive / (true_positive + false_positive + false_negative)
+    #         return true_positive / (true_positive + false_positive + false_negative)
 
-        elif total_gt == 0:
-            if total_pred > 0:
-                return torch.tensor(0.0).cuda()
-            else:
-                return torch.tensor(1.0).cuda()
-        elif total_gt > 0 and total_pred == 0:
-            return torch.tensor(0.0).cuda()
+    #     elif total_gt == 0:
+    #         if total_pred > 0:
+    #             return torch.tensor(0.0).cuda()
+    #         else:
+    #             return torch.tensor(1.0).cuda()
+    #     elif total_gt > 0 and total_pred == 0:
+    #         return torch.tensor(0.0).cuda()
 
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(
             self.parameters(),
             lr=wandb.config.LEARNING_RATE,
             momentum=wandb.config.MOMENTUM,
             weight_decay=wandb.config.WEIGHT_DECAY,
-            nesterov=wandb.config.NESTEROV,
-        )
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer,
-            T_0=3,
-            T_mult=1,
-            lr=wandb.config.LEARNING_RATE_MIN,
-            verbose=True,
         )
 
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": scheduler,
-                "monitor": "val_accuracy",
-            },
-        }
+        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+        #     optimizer,
+        #     T_0=3,
+        #     T_mult=1,
+        #     eta_min=wandb.config.LEARNING_RATE_MIN,  # the kwarg is eta_min, not lr
+        #     verbose=True,
+        # )
+
+        # return {
+        #     "optimizer": optimizer,
"optimizer": optimizer, + # "lr_scheduler": { + # "scheduler": scheduler, + # "monitor": "val_accuracy", + # }, + # } + + return optimizer diff --git a/src/train.py b/src/train.py index 86d85e6..f7027bb 100644 --- a/src/train.py +++ b/src/train.py @@ -1,10 +1,10 @@ import logging import pytorch_lightning as pl -import wandb from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger +import wandb from data import Spheres from mrcnn import MRCNNModule from unet import UNetModule @@ -58,10 +58,10 @@ if __name__ == "__main__": precision=16, logger=logger, log_every_n_steps=1, - val_check_interval=100, + # val_check_interval=100, callbacks=[RichProgressBar(), ArtifactLog(), TableLog()], # profiler="simple", - # num_sanity_val_steps=0, + num_sanity_val_steps=0, ) # actually train the model diff --git a/wandb.yaml b/wandb.yaml index c26bf50..738b412 100644 --- a/wandb.yaml +++ b/wandb.yaml @@ -17,11 +17,11 @@ AMP: PIN_MEMORY: value: True BENCHMARK: - value: True + value: False DEVICE: value: gpu WORKERS: - value: 8 + value: 1 IMG_SIZE: value: 512 @@ -31,11 +31,11 @@ SPHERES: EPOCHS: value: 3 TRAIN_BATCH_SIZE: - value: 128 # 100 + value: 2 # 100 VAL_BATCH_SIZE: - value: 8 # 10 + value: 0 # 10 PREFETCH_FACTOR: - value: 2 + value: 1 LEARNING_RATE: value: 1.0e-4