feat: split two types of datasets
Former-commit-id: 2609316692d315f4b0df614c533bf28d20ffaf21 [formerly c1a425cb33fefa2809e591f0fe527236f6386863] Former-commit-id: c3f96d3f272652a6162b17112be0e722c99eef57
This commit is contained in:
parent
5f46efa5a1
commit
81cbfd6212
|
@ -1,12 +1,11 @@
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from albumentations.pytorch import ToTensorV2
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from utils import RandomPaste
|
from utils import RandomPaste
|
||||||
|
|
||||||
from .dataset import SphereDataset
|
from .dataset import LabeledDataset, SyntheticDataset
|
||||||
|
|
||||||
|
|
||||||
class SyntheticSphere(pl.LightningDataModule):
|
class SyntheticSphere(pl.LightningDataModule):
|
||||||
|
@ -14,7 +13,7 @@ class SyntheticSphere(pl.LightningDataModule):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
tf_train = A.Compose(
|
transform = A.Compose(
|
||||||
[
|
[
|
||||||
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||||
A.Flip(),
|
A.Flip(),
|
||||||
|
@ -22,16 +21,14 @@ class SyntheticSphere(pl.LightningDataModule):
|
||||||
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
|
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
|
||||||
A.GaussianBlur(),
|
A.GaussianBlur(),
|
||||||
A.ISONoise(),
|
A.ISONoise(),
|
||||||
A.ToFloat(max_value=255),
|
|
||||||
ToTensorV2(),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
||||||
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
ds_train,
|
dataset,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
batch_size=wandb.config.BATCH_SIZE,
|
batch_size=wandb.config.BATCH_SIZE,
|
||||||
num_workers=wandb.config.WORKERS,
|
num_workers=wandb.config.WORKERS,
|
||||||
|
@ -39,10 +36,10 @@ class SyntheticSphere(pl.LightningDataModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
ds_valid,
|
dataset,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
num_workers=wandb.config.WORKERS,
|
num_workers=wandb.config.WORKERS,
|
||||||
|
|
|
@ -7,8 +7,8 @@ from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
class SphereDataset(Dataset):
|
class SyntheticDataset(Dataset):
|
||||||
def __init__(self, image_dir, transform=None):
|
def __init__(self, image_dir, transform):
|
||||||
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
|
||||||
|
@ -16,30 +16,60 @@ class SphereDataset(Dataset):
|
||||||
return len(self.images)
|
return len(self.images)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
# open and convert image
|
||||||
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
|
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
|
||||||
|
|
||||||
if self.transform is not None:
|
# create empty mask of same size
|
||||||
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
augmentations = self.transform(image=image, mask=mask)
|
|
||||||
image = augmentations["image"]
|
|
||||||
mask = augmentations["mask"]
|
|
||||||
else:
|
|
||||||
mask_path = self.images[index].parent.joinpath("MASK.PNG")
|
|
||||||
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255
|
|
||||||
|
|
||||||
preprocess = A.Compose(
|
# augment image and mask
|
||||||
[
|
augmentations = self.transform(image=image, mask=mask)
|
||||||
A.SmallestMaxSize(1024),
|
image = augmentations["image"]
|
||||||
A.ToFloat(max_value=255),
|
mask = augmentations["mask"]
|
||||||
ToTensorV2(),
|
|
||||||
],
|
# convert image & mask to Tensor float in [0, 1]
|
||||||
)
|
post_process = A.Compose(
|
||||||
augmentations = preprocess(image=image, mask=mask)
|
[
|
||||||
image = augmentations["image"]
|
A.ToFloat(max_value=255),
|
||||||
mask = augmentations["mask"]
|
ToTensorV2(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
augmentations = post_process(image=image, mask=mask)
|
||||||
|
image = augmentations["image"]
|
||||||
|
mask = augmentations["mask"]
|
||||||
|
|
||||||
# make sure image and mask are floats
|
# make sure image and mask are floats
|
||||||
image = image.float()
|
image = image.float()
|
||||||
mask = mask.float()
|
mask = mask.float()
|
||||||
|
|
||||||
return image, mask
|
return image, mask
|
||||||
|
|
||||||
|
|
||||||
|
class LabeledDataset(Dataset):
    """Dataset of JPEG images paired with a ``MASK.PNG`` stored next to each image.

    Every ``*.jpg`` found recursively under ``image_dir`` is one sample. Its
    mask is the file named ``MASK.PNG`` located in the same directory as the
    image.

    Args:
        image_dir: root directory searched recursively for ``*.jpg`` files.
    """

    def __init__(self, image_dir):
        self.images = list(Path(image_dir).glob("**/*.jpg"))

        # convert image & mask to Tensor float in [0, 1]
        # The pipeline does not depend on the sample, so it is built once here
        # instead of being re-created on every __getitem__ call.
        self.post_process = A.Compose(
            [
                # A.SmallestMaxSize(1024),
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ],
        )

    def __len__(self):
        """Return the number of images found under ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the image at ``index`` and its mask, returned as ``(image, mask)``.

        Both returned values are cast with ``.float()`` so they share the
        dtype produced by ``SyntheticDataset.__getitem__``.
        """
        image_path = self.images[index]

        # open and convert image (context manager closes the file handle)
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"), dtype=np.uint8)

        # open and convert mask
        mask_path = image_path.parent.joinpath("MASK.PNG")
        with Image.open(mask_path) as img:
            # true division turns the uint8 array into a float64 array in [0, 1]
            mask = np.array(img.convert("L"), dtype=np.uint8) / 255

        augmentations = self.post_process(image=image, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]

        # make sure image and mask are floats
        # NOTE(review): without this cast the mask presumably stays float64
        # (ToFloat is expected to touch only the image) — confirm against the
        # albumentations version in use.
        image = image.float()
        mask = mask.float()

        return image, mask
|
||||||
|
|
Loading…
Reference in a new issue