feat: split two types of datasets

Former-commit-id: 2609316692d315f4b0df614c533bf28d20ffaf21 [formerly c1a425cb33fefa2809e591f0fe527236f6386863]
Former-commit-id: c3f96d3f272652a6162b17112be0e722c99eef57
This commit is contained in:
Laurent Fainsin 2022-07-08 16:23:22 +02:00
parent 5f46efa5a1
commit 81cbfd6212
2 changed files with 56 additions and 29 deletions

View file

@ -1,12 +1,11 @@
import albumentations as A import albumentations as A
import pytorch_lightning as pl import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import wandb import wandb
from utils import RandomPaste from utils import RandomPaste
from .dataset import SphereDataset from .dataset import LabeledDataset, SyntheticDataset
class SyntheticSphere(pl.LightningDataModule): class SyntheticSphere(pl.LightningDataModule):
@ -14,7 +13,7 @@ class SyntheticSphere(pl.LightningDataModule):
super().__init__() super().__init__()
def train_dataloader(self): def train_dataloader(self):
tf_train = A.Compose( transform = A.Compose(
[ [
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Flip(), A.Flip(),
@ -22,16 +21,14 @@ class SyntheticSphere(pl.LightningDataModule):
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE), RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
A.GaussianBlur(), A.GaussianBlur(),
A.ISONoise(), A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
], ],
) )
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
return DataLoader( return DataLoader(
ds_train, dataset,
shuffle=True, shuffle=True,
batch_size=wandb.config.BATCH_SIZE, batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS, num_workers=wandb.config.WORKERS,
@ -39,10 +36,10 @@ class SyntheticSphere(pl.LightningDataModule):
) )
def val_dataloader(self): def val_dataloader(self):
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG) dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
return DataLoader( return DataLoader(
ds_valid, dataset,
shuffle=False, shuffle=False,
batch_size=1, batch_size=1,
num_workers=wandb.config.WORKERS, num_workers=wandb.config.WORKERS,

View file

@ -7,8 +7,8 @@ from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset
class SphereDataset(Dataset): class SyntheticDataset(Dataset):
def __init__(self, image_dir, transform=None): def __init__(self, image_dir, transform):
self.images = list(Path(image_dir).glob("**/*.jpg")) self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transform = transform self.transform = transform
@ -16,30 +16,60 @@ class SphereDataset(Dataset):
return len(self.images) return len(self.images)
def __getitem__(self, index): def __getitem__(self, index):
# open and convert image
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8) image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
if self.transform is not None: # create empty mask of same size
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
augmentations = self.transform(image=image, mask=mask)
image = augmentations["image"]
mask = augmentations["mask"]
else:
mask_path = self.images[index].parent.joinpath("MASK.PNG")
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255
preprocess = A.Compose( # augment image and mask
[ augmentations = self.transform(image=image, mask=mask)
A.SmallestMaxSize(1024), image = augmentations["image"]
A.ToFloat(max_value=255), mask = augmentations["mask"]
ToTensorV2(),
], # convert image & mask to Tensor float in [0, 1]
) post_process = A.Compose(
augmentations = preprocess(image=image, mask=mask) [
image = augmentations["image"] A.ToFloat(max_value=255),
mask = augmentations["mask"] ToTensorV2(),
],
)
augmentations = post_process(image=image, mask=mask)
image = augmentations["image"]
mask = augmentations["mask"]
# make sure image and mask are floats # make sure image and mask are floats
image = image.float() image = image.float()
mask = mask.float() mask = mask.float()
return image, mask return image, mask
class LabeledDataset(Dataset):
    """Dataset of images paired with a mask file stored next to them.

    Every ``*.jpg`` found recursively under ``image_dir`` is one sample. Its
    mask is read from the file named ``MASK.PNG`` in the same directory as
    the image.
    """

    def __init__(self, image_dir):
        """Collect the paths of all ``*.jpg`` files below ``image_dir``."""
        self.images = list(Path(image_dir).glob("**/*.jpg"))

    def __len__(self):
        """Return the number of images that were found."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, mask)`` as float tensors."""
        image_path = self.images[index]

        # open and convert image; the context manager releases the file
        # handle as soon as the pixels have been copied into the array
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"), dtype=np.uint8)

        # open and convert mask; the true division rescales it to [0, 1]
        # and, as a side effect, promotes the array from uint8 to float64
        mask_path = image_path.parent.joinpath("MASK.PNG")
        with Image.open(mask_path) as img:
            mask = np.array(img.convert("L"), dtype=np.uint8) / 255

        # convert image & mask to Tensor float in [0, 1]
        post_process = A.Compose(
            [
                # A.SmallestMaxSize(1024),
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ],
        )
        augmentations = post_process(image=image, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]

        # make sure image and mask are floats, as SyntheticDataset does;
        # without this the mask presumably keeps the float64 dtype it got
        # from the division above and no longer matches the image dtype
        image = image.float()
        mask = mask.float()

        return image, mask