feat: create new datasets for generation and loading of pre-rendered images

Former-commit-id: 4ccd97c3583c7def1e5e988a5beb8fcda7545fc9 [formerly 6a7e1eb28ade30f3e4b76337c209ac7e5b3b1cbf]
Former-commit-id: 8a552d696c6e47378a36a8680e3a09d151de25f1
Laurent Fainsin 2022-09-12 09:27:49 +02:00
parent c6c08ac98a
commit 0693f02d83
2 changed files with 125 additions and 11 deletions
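The new LabeledDataset added in the second file below pairs every pre-rendered *.jpg it finds with a MASK.PNG stored next to it, so the data root is expected to look roughly like this (directory and image names are illustrative; only MASK.PNG is taken from the code):

/dev/shm/TRAIN/
├── scene_0001/
│   ├── render.jpg   <- any *.jpg matched by Path(image_dir).glob("**/*.jpg")
│   └── MASK.PNG     <- grayscale instance mask; gray level 0 is background, each other value is one object
└── scene_0002/
    └── ...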


@@ -2,9 +2,9 @@ import albumentations as A
import pytorch_lightning as pl
import wandb
from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
-from .dataset import RealDataset
+from .dataset import LabeledDataset, RealDataset
def collate_fn(batch):
@@ -18,13 +18,13 @@ class Spheres(pl.LightningDataModule):
def train_dataloader(self):
transforms = A.Compose(
[
-A.Flip(),
-A.ColorJitter(),
-A.ToGray(p=0.01),
-A.GaussianBlur(),
-A.MotionBlur(),
-A.ISONoise(),
-A.ImageCompression(),
+# A.Flip(),
+# A.ColorJitter(),
+# A.ToGray(p=0.01),
+# A.GaussianBlur(),
+# A.MotionBlur(),
+# A.ISONoise(),
+# A.ImageCompression(),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
@@ -40,7 +40,9 @@ class Spheres(pl.LightningDataModule):
),
)
-dataset = RealDataset(root="/dev/shm/TRAIN/", transforms=transforms)
+dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
+# dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
+# dataset = Subset(dataset, [0] * 320) # overfit test
return DataLoader(
dataset,

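The collate_fn imported and passed to the DataLoader in the datamodule above is not part of this diff; for detection-style batches where each image contains a different number of objects it is typically just a tuple-zip. A minimal sketch (not the committed implementation):

def collate_fn(batch):
    # keep images and targets as two tuples instead of stacking,
    # since each sample may contain a different number of objects
    return tuple(zip(*batch))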

@@ -1,4 +1,5 @@
import os
+from pathlib import Path
import albumentations as A
import numpy as np
@@ -7,6 +8,37 @@ from PIL import Image
from torch.utils.data import Dataset
class SyntheticDataset(Dataset):
def __init__(self, image_dir, transform):
self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self, index):
# open and convert image
image = np.ascontiguousarray(
Image.open(
self.images[index],
).convert("RGB"),
dtype=np.uint8,
)
# create empty mask of same size
mask = np.zeros(
(*image.shape[:2], 4),
dtype=np.uint8,
)
# augment image and mask
augmentations = self.transform(image=image, mask=mask)
image = augmentations["image"]
mask = augmentations["mask"]
return image, mask
class RealDataset(Dataset):
def __init__(self, root, transforms=None):
self.root = root
@@ -16,7 +48,7 @@ class RealDataset(Dataset):
self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
-self.res = A.SmallestMaxSize(max_size=1024)
+self.res = A.LongestMaxSize(max_size=1024)
def __getitem__(self, idx):
# create paths from ids
@@ -97,3 +129,83 @@ class RealDataset(Dataset):
def __len__(self):
return len(self.imgs)
class LabeledDataset(Dataset):
def __init__(self, image_dir, transforms):
self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transforms = transforms
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
# open and convert image
image = np.ascontiguousarray(
Image.open(self.images[idx]).convert("RGB"),
)
# open and convert mask
mask_path = self.images[idx].parent.joinpath("MASK.PNG")
mask = np.ascontiguousarray(
Image.open(mask_path).convert("L"),
)
# get ids from mask
obj_ids = np.unique(mask)
obj_ids = obj_ids[1:] # first id is the background, so remove it
# split the color-encoded mask into a set of binary masks
masks = mask == obj_ids[:, None, None]
masks = masks.astype(np.uint8) # cast to uint8 for albumentations
# create bboxes from masks (pascal format)
num_objs = len(obj_ids)
bboxes = []
for i in range(num_objs):
pos = np.where(masks[i])
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
ymax = np.max(pos[0])
bboxes.append([xmin, ymin, xmax, ymax])
# convert arrays for albumentations
bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
labels = torch.ones((num_objs,), dtype=torch.int64) # assume there is only one class (id=1)
masks = list(np.asarray(masks))
if self.transforms is not None:
# arrange transform data
data = {
"image": image,
"labels": labels,
"bboxes": bboxes,
"masks": masks,
}
# apply transform
augmented = self.transforms(**data)
# get augmented data
image = augmented["image"]
bboxes = augmented["bboxes"]
labels = augmented["labels"]
masks = augmented["masks"]
bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
labels = torch.as_tensor(labels, dtype=torch.int64) # int64 required by torchvision maskrcnn
masks = torch.stack(masks) # stack masks, required by torchvision maskrcnn
area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
image_id = torch.tensor([idx])
iscrowd = torch.zeros((num_objs,), dtype=torch.int64) # assume all instances are not crowd
target = {
"boxes": bboxes,
"labels": labels,
"masks": masks,
"area": area,
"image_id": image_id,
"iscrowd": iscrowd,
}
return image, target
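Because LabeledDataset forwards bboxes, labels and masks through the Albumentations call and then runs torch.stack on the returned masks, the transform it receives has to declare bbox_params and end with ToTensorV2. The real pipeline lives in the datamodule above and is only partially visible in this diff; a minimal compatible sketch with illustrative values:

import albumentations as A
from albumentations.pytorch import ToTensorV2

from dataset import LabeledDataset  # hypothetical import path for this sketch

transform = A.Compose(
    [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),  # image and masks become torch tensors, so torch.stack(masks) works
    ],
    bbox_params=A.BboxParams(
        format="pascal_voc",      # [xmin, ymin, xmax, ymax], matching the boxes built above
        label_fields=["labels"],  # keep labels aligned with boxes that survive augmentation
    ),
)

dataset = LabeledDataset("/dev/shm/TRAIN/", transform)
image, target = dataset[0]  # target holds boxes, labels, masks, area, image_id, iscrowd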