From 81cbfd6212e0663c13f2085085a74dcdf54e5be4 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Fri, 8 Jul 2022 16:23:22 +0200 Subject: [PATCH] feat: split two types of datasets Former-commit-id: 2609316692d315f4b0df614c533bf28d20ffaf21 [formerly c1a425cb33fefa2809e591f0fe527236f6386863] Former-commit-id: c3f96d3f272652a6162b17112be0e722c99eef57 --- src/data/dataloader.py | 15 ++++----- src/data/dataset.py | 70 ++++++++++++++++++++++++++++++------------ 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/src/data/dataloader.py b/src/data/dataloader.py index f575d3a..3d585c2 100644 --- a/src/data/dataloader.py +++ b/src/data/dataloader.py @@ -1,12 +1,11 @@ import albumentations as A import pytorch_lightning as pl -from albumentations.pytorch import ToTensorV2 from torch.utils.data import DataLoader import wandb from utils import RandomPaste -from .dataset import SphereDataset +from .dataset import LabeledDataset, SyntheticDataset class SyntheticSphere(pl.LightningDataModule): @@ -14,7 +13,7 @@ class SyntheticSphere(pl.LightningDataModule): super().__init__() def train_dataloader(self): - tf_train = A.Compose( + transform = A.Compose( [ A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), A.Flip(), @@ -22,16 +21,14 @@ class SyntheticSphere(pl.LightningDataModule): RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE), A.GaussianBlur(), A.ISONoise(), - A.ToFloat(max_value=255), - ToTensorV2(), ], ) - ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) + dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform) # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) return DataLoader( - ds_train, + dataset, shuffle=True, batch_size=wandb.config.BATCH_SIZE, num_workers=wandb.config.WORKERS, @@ -39,10 +36,10 @@ class SyntheticSphere(pl.LightningDataModule): ) def val_dataloader(self): - ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG) + dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG) return DataLoader( - ds_valid, + dataset, shuffle=False, batch_size=1, num_workers=wandb.config.WORKERS, diff --git a/src/data/dataset.py b/src/data/dataset.py index bf821a4..01c50ea 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -7,8 +7,8 @@ from PIL import Image from torch.utils.data import Dataset -class SphereDataset(Dataset): - def __init__(self, image_dir, transform=None): +class SyntheticDataset(Dataset): + def __init__(self, image_dir, transform): self.images = list(Path(image_dir).glob("**/*.jpg")) self.transform = transform @@ -16,30 +16,60 @@ class SphereDataset(Dataset): return len(self.images) def __getitem__(self, index): + # open and convert image image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8) - if self.transform is not None: - mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) - augmentations = self.transform(image=image, mask=mask) - image = augmentations["image"] - mask = augmentations["mask"] - else: - mask_path = self.images[index].parent.joinpath("MASK.PNG") - mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255 + # create empty mask of same size + mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) - preprocess = A.Compose( - [ - A.SmallestMaxSize(1024), - A.ToFloat(max_value=255), - ToTensorV2(), - ], - ) - augmentations = preprocess(image=image, mask=mask) - image = augmentations["image"] - mask = augmentations["mask"] + # augment image and mask + augmentations = self.transform(image=image, mask=mask) + image = augmentations["image"] + mask = augmentations["mask"] + + # convert image & mask to Tensor float in [0, 1] + post_process = A.Compose( + [ + A.ToFloat(max_value=255), + ToTensorV2(), + ], + ) + augmentations = post_process(image=image, mask=mask) + image = augmentations["image"] + mask = augmentations["mask"] # make sure image and mask are floats image = image.float() mask = mask.float() return image, mask + + +class LabeledDataset(Dataset): + def __init__(self, image_dir): + self.images = list(Path(image_dir).glob("**/*.jpg")) + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + # open and convert image + image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8) + + # open and convert mask + mask_path = self.images[index].parent.joinpath("MASK.PNG") + mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255 + + # convert image & mask to Tensor float in [0, 1] + post_process = A.Compose( + [ + # A.SmallestMaxSize(1024), + A.ToFloat(max_value=255), + ToTensorV2(), + ], + ) + augmentations = post_process(image=image, mask=mask) + image = augmentations["image"] + mask = augmentations["mask"] + + return image, mask