feat: create new datasets for generation and loading of pre-rendered images

Former-commit-id: 4ccd97c3583c7def1e5e988a5beb8fcda7545fc9 [formerly 6a7e1eb28ade30f3e4b76337c209ac7e5b3b1cbf]
Former-commit-id: 8a552d696c6e47378a36a8680e3a09d151de25f1
Laurent Fainsin 2022-09-12 09:27:49 +02:00
parent c6c08ac98a
commit 0693f02d83
2 changed files with 125 additions and 11 deletions

@@ -2,9 +2,9 @@ import albumentations as A
 import pytorch_lightning as pl
 import wandb
 from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
 
-from .dataset import RealDataset
+from .dataset import LabeledDataset, RealDataset
 
 
 def collate_fn(batch):
@@ -18,13 +18,13 @@ class Spheres(pl.LightningDataModule):
     def train_dataloader(self):
         transforms = A.Compose(
             [
-                A.Flip(),
-                A.ColorJitter(),
-                A.ToGray(p=0.01),
-                A.GaussianBlur(),
-                A.MotionBlur(),
-                A.ISONoise(),
-                A.ImageCompression(),
+                # A.Flip(),
+                # A.ColorJitter(),
+                # A.ToGray(p=0.01),
+                # A.GaussianBlur(),
+                # A.MotionBlur(),
+                # A.ISONoise(),
+                # A.ImageCompression(),
                 A.Normalize(
                     mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225],
@@ -40,7 +40,9 @@ class Spheres(pl.LightningDataModule):
             ),
         )
 
-        dataset = RealDataset(root="/dev/shm/TRAIN/", transforms=transforms)
+        dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
+        # dataset = Subset(dataset, range(6 * 200))  # subset for debugging purpose
+        # dataset = Subset(dataset, [0] * 320)  # overfit test
 
         return DataLoader(
             dataset,
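
Note: the body of collate_fn is not visible in this diff. For torchvision detection models it is commonly the identity regrouping sketched below; this is an assumption, not necessarily the repository's implementation:

def collate_fn(batch):
    # (image, target) pairs carry variable-size target dicts, so they
    # cannot be stacked into a single tensor; regroup the batch into
    # a tuple of images and a tuple of targets instead
    return tuple(zip(*batch))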

dataset.py

@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 
 import albumentations as A
 import numpy as np
@@ -7,6 +8,37 @@ from PIL import Image
 from torch.utils.data import Dataset
 
 
+class SyntheticDataset(Dataset):
+    def __init__(self, image_dir, transform):
+        self.images = list(Path(image_dir).glob("**/*.jpg"))
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        # open and convert image
+        image = np.ascontiguousarray(
+            Image.open(
+                self.images[index],
+            ).convert("RGB"),
+            dtype=np.uint8,
+        )
+
+        # create empty mask of same size
+        mask = np.zeros(
+            (*image.shape[:2], 4),
+            dtype=np.uint8,
+        )
+
+        # augment image and mask
+        augmentations = self.transform(image=image, mask=mask)
+        image = augmentations["image"]
+        mask = augmentations["mask"]
+
+        return image, mask
+
+
 class RealDataset(Dataset):
     def __init__(self, root, transforms=None):
         self.root = root
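
Note: SyntheticDataset walks a directory tree of pre-rendered .jpg files and pairs each image with an empty 4-channel mask, leaving all labelling to downstream code. A minimal usage sketch follows; the directory path and the transform contents are illustrative assumptions, not taken from this commit:

import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose(
    [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ],
)

# "/dev/shm/RENDERS/" is a hypothetical directory of pre-rendered images
dataset = SyntheticDataset("/dev/shm/RENDERS/", transform)
image, mask = dataset[0]  # normalized image tensor and its (empty) mask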
@@ -16,7 +48,7 @@ class RealDataset(Dataset):
         self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
         self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
 
-        self.res = A.SmallestMaxSize(max_size=1024)
+        self.res = A.LongestMaxSize(max_size=1024)
 
     def __getitem__(self, idx):
         # create paths from ids
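
Note: swapping SmallestMaxSize for LongestMaxSize changes which side the 1024-pixel cap applies to: the longest side is now bounded, so every resized image fits inside a 1024x1024 budget instead of being scaled until its shortest side reaches 1024. A quick sketch of the difference:

import albumentations as A
import numpy as np

image = np.zeros((1000, 2000, 3), dtype=np.uint8)  # H=1000, W=2000

# old behaviour: smallest side scaled up to 1024 -> (1024, 2048, 3)
print(A.SmallestMaxSize(max_size=1024)(image=image)["image"].shape)

# new behaviour: longest side scaled down to 1024 -> (512, 1024, 3)
print(A.LongestMaxSize(max_size=1024)(image=image)["image"].shape)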
@@ -97,3 +129,83 @@ class RealDataset(Dataset):
 
     def __len__(self):
         return len(self.imgs)
+
+
+class LabeledDataset(Dataset):
+    def __init__(self, image_dir, transforms):
+        self.images = list(Path(image_dir).glob("**/*.jpg"))
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        # open and convert image
+        image = np.ascontiguousarray(
+            Image.open(self.images[idx]).convert("RGB"),
+        )
+
+        # open and convert mask
+        mask_path = self.images[idx].parent.joinpath("MASK.PNG")
+        mask = np.ascontiguousarray(
+            Image.open(mask_path).convert("L"),
+        )
+
+        # get ids from mask
+        obj_ids = np.unique(mask)
+        obj_ids = obj_ids[1:]  # first id is the background, so remove it
+
+        # split the color-encoded mask into a set of binary masks
+        masks = mask == obj_ids[:, None, None]
+        masks = masks.astype(np.uint8)  # cast to uint8 for albumentations
+
+        # create bboxes from masks (pascal format)
+        num_objs = len(obj_ids)
+        bboxes = []
+        for i in range(num_objs):
+            pos = np.where(masks[i])
+            xmin = np.min(pos[1])
+            xmax = np.max(pos[1])
+            ymin = np.min(pos[0])
+            ymax = np.max(pos[0])
+            bboxes.append([xmin, ymin, xmax, ymax])
+
+        # convert arrays for albumentations
+        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
+        labels = torch.ones((num_objs,), dtype=torch.int64)  # assume there is only one class (id=1)
+        masks = list(np.asarray(masks))
+
+        if self.transforms is not None:
+            # arrange transform data
+            data = {
+                "image": image,
+                "labels": labels,
+                "bboxes": bboxes,
+                "masks": masks,
+            }
+
+            # apply transform
+            augmented = self.transforms(**data)
+
+            # get augmented data
+            image = augmented["image"]
+            bboxes = augmented["bboxes"]
+            labels = augmented["labels"]
+            masks = augmented["masks"]
+
+        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
+        labels = torch.as_tensor(labels, dtype=torch.int64)  # int64 required by torchvision maskrcnn
+        masks = torch.stack(masks)  # stack masks, required by torchvision maskrcnn
+
+        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        image_id = torch.tensor([idx])
+        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # assume all instances are not crowd
+
+        target = {
+            "boxes": bboxes,
+            "labels": labels,
+            "masks": masks,
+            "area": area,
+            "image_id": image_id,
+            "iscrowd": iscrowd,
+        }
+        return image, target
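
Note: LabeledDataset relies on albumentations keeping "bboxes" and "labels" in sync with the image, which requires bbox_params on the Compose. A minimal end-to-end sketch follows; only the TRAIN path appears in this commit, and the transform contents, batch size, and inline collate function are illustrative assumptions:

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

# "pascal_voc" matches the [xmin, ymin, xmax, ymax] boxes built in
# __getitem__ above; label_fields carries "labels" through the transform
transforms = A.Compose(
    [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)

dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=lambda batch: tuple(zip(*batch)),  # keep per-image target dicts
)
images, targets = next(iter(loader))  # consumable by torchvision's MaskRCNN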