diff --git a/src/data/dataloader.py b/src/data/dataloader.py
index bc3a8e7..120ac10 100644
--- a/src/data/dataloader.py
+++ b/src/data/dataloader.py
@@ -2,9 +2,9 @@ import albumentations as A
 import pytorch_lightning as pl
 import wandb
 from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
 
-from .dataset import RealDataset
+from .dataset import LabeledDataset, RealDataset
 
 
 def collate_fn(batch):
@@ -18,13 +18,13 @@ class Spheres(pl.LightningDataModule):
     def train_dataloader(self):
         transforms = A.Compose(
             [
-                A.Flip(),
-                A.ColorJitter(),
-                A.ToGray(p=0.01),
-                A.GaussianBlur(),
-                A.MotionBlur(),
-                A.ISONoise(),
-                A.ImageCompression(),
+                # A.Flip(),
+                # A.ColorJitter(),
+                # A.ToGray(p=0.01),
+                # A.GaussianBlur(),
+                # A.MotionBlur(),
+                # A.ISONoise(),
+                # A.ImageCompression(),
                 A.Normalize(
                     mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225],
@@ -40,7 +40,9 @@ class Spheres(pl.LightningDataModule):
             ),
         )
 
-        dataset = RealDataset(root="/dev/shm/TRAIN/", transforms=transforms)
+        dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
+        # dataset = Subset(dataset, range(6 * 200))  # subset for debugging purpose
+        # dataset = Subset(dataset, [0] * 320)  # overfit test
 
         return DataLoader(
             dataset,
diff --git a/src/data/dataset.py b/src/data/dataset.py
index e75e9b9..3addc8b 100644
--- a/src/data/dataset.py
+++ b/src/data/dataset.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 
 import albumentations as A
 import numpy as np
@@ -7,6 +8,37 @@ from PIL import Image
 from torch.utils.data import Dataset
 
 
+class SyntheticDataset(Dataset):
+    def __init__(self, image_dir, transform):
+        self.images = list(Path(image_dir).glob("**/*.jpg"))
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        # open and convert image
+        image = np.ascontiguousarray(
+            Image.open(
+                self.images[index],
+            ).convert("RGB"),
+            dtype=np.uint8,
+        )
+
+        # create empty mask of same size
+        mask = np.zeros(
+            (*image.shape[:2], 4),
+            dtype=np.uint8,
+        )
+
+        # augment image and mask
+        augmentations = self.transform(image=image, mask=mask)
+        image = augmentations["image"]
+        mask = augmentations["mask"]
+
+        return image, mask
+
+
 class RealDataset(Dataset):
     def __init__(self, root, transforms=None):
         self.root = root
@@ -16,7 +48,7 @@ class RealDataset(Dataset):
         self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
         self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
 
-        self.res = A.SmallestMaxSize(max_size=1024)
+        self.res = A.LongestMaxSize(max_size=1024)
 
     def __getitem__(self, idx):
         # create paths from ids
@@ -97,3 +129,83 @@ class RealDataset(Dataset):
 
     def __len__(self):
         return len(self.imgs)
+
+
+class LabeledDataset(Dataset):
+    def __init__(self, image_dir, transforms):
+        self.images = list(Path(image_dir).glob("**/*.jpg"))
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        # open and convert image
+        image = np.ascontiguousarray(
+            Image.open(self.images[idx]).convert("RGB"),
+        )
+
+        # open and convert mask
+        mask_path = self.images[idx].parent.joinpath("MASK.PNG")
+        mask = np.ascontiguousarray(
+            Image.open(mask_path).convert("L"),
+        )
+
+        # get ids from mask
+        obj_ids = np.unique(mask)
+        obj_ids = obj_ids[1:]  # first id is the background, so remove it
+
+        # split the color-encoded mask into a set of binary masks
+        masks = mask == obj_ids[:, None, None]
+        masks = masks.astype(np.uint8)  # cast to uint8 for albumentations
+
+        # create bboxes from masks (pascal format)
+        num_objs = len(obj_ids)
+        bboxes = []
+        for i in range(num_objs):
+            pos = np.where(masks[i])
+            xmin = np.min(pos[1])
+            xmax = np.max(pos[1])
+            ymin = np.min(pos[0])
+            ymax = np.max(pos[0])
+            bboxes.append([xmin, ymin, xmax, ymax])
+
+        # convert arrays for albumentations
+        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
+        labels = torch.ones((num_objs,), dtype=torch.int64)  # assume there is only one class (id=1)
+        masks = list(np.asarray(masks))
+
+        if self.transforms is not None:
+            # arrange transform data
+            data = {
+                "image": image,
+                "labels": labels,
+                "bboxes": bboxes,
+                "masks": masks,
+            }
+            # apply transform
+            augmented = self.transforms(**data)
+            # get augmented data
+            image = augmented["image"]
+            bboxes = augmented["bboxes"]
+            labels = augmented["labels"]
+            masks = augmented["masks"]
+
+        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
+        labels = torch.as_tensor(labels, dtype=torch.int64)  # int64 required by torchvision maskrcnn
+        masks = torch.stack(masks)  # stack masks, required by torchvision maskrcnn
+
+        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        image_id = torch.tensor([idx])
+        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # assume all instances are not crowd
+
+        target = {
+            "boxes": bboxes,
+            "labels": labels,
+            "masks": masks,
+            "area": area,
+            "image_id": image_id,
+            "iscrowd": iscrowd,
+        }
+
+        return image, target
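
Quick sanity check for the new LabeledDataset and target format (a sketch only, not part of the diff): the BboxParams settings, the src.data import path, the two-class head, and the float cast of the boxes are assumptions; the Normalize values and the "/dev/shm/TRAIN/" root are taken from train_dataloader above.

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.models.detection import maskrcnn_resnet50_fpn

from src.data.dataset import LabeledDataset

# same Normalize/ToTensorV2 tail as in train_dataloader(); BboxParams values are assumed here
transforms = A.Compose(
    [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)

dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
image, target = dataset[0]

# torchvision's Mask R-CNN expects float boxes; num_classes=2 (background + sphere) is assumed
target["boxes"] = target["boxes"].float()
model = maskrcnn_resnet50_fpn(num_classes=2)
model.train()
loss_dict = model([image], [target])  # rpn/classifier/box/mask losses in training mode
print({name: value.item() for name, value in loss_dict.items()})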