mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: create new datasets for generation and loading of pre rendered images
Former-commit-id: 4ccd97c3583c7def1e5e988a5beb8fcda7545fc9 [formerly 6a7e1eb28ade30f3e4b76337c209ac7e5b3b1cbf] Former-commit-id: 8a552d696c6e47378a36a8680e3a09d151de25f1
This commit is contained in:
parent
c6c08ac98a
commit
0693f02d83
|
@ -2,9 +2,9 @@ import albumentations as A
|
|||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from .dataset import RealDataset
|
||||
from .dataset import LabeledDataset, RealDataset
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
|
@ -18,13 +18,13 @@ class Spheres(pl.LightningDataModule):
|
|||
def train_dataloader(self):
|
||||
transforms = A.Compose(
|
||||
[
|
||||
A.Flip(),
|
||||
A.ColorJitter(),
|
||||
A.ToGray(p=0.01),
|
||||
A.GaussianBlur(),
|
||||
A.MotionBlur(),
|
||||
A.ISONoise(),
|
||||
A.ImageCompression(),
|
||||
# A.Flip(),
|
||||
# A.ColorJitter(),
|
||||
# A.ToGray(p=0.01),
|
||||
# A.GaussianBlur(),
|
||||
# A.MotionBlur(),
|
||||
# A.ISONoise(),
|
||||
# A.ImageCompression(),
|
||||
A.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
|
@ -40,7 +40,9 @@ class Spheres(pl.LightningDataModule):
|
|||
),
|
||||
)
|
||||
|
||||
dataset = RealDataset(root="/dev/shm/TRAIN/", transforms=transforms)
|
||||
dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
|
||||
# dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
|
||||
# dataset = Subset(dataset, [0] * 320) # overfit test
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
|
@ -7,6 +8,37 @@ from PIL import Image
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class SyntheticDataset(Dataset):
|
||||
def __init__(self, image_dir, transform):
|
||||
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, index):
|
||||
# open and convert image
|
||||
image = np.ascontiguousarray(
|
||||
Image.open(
|
||||
self.images[index],
|
||||
).convert("RGB"),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
# create empty mask of same size
|
||||
mask = np.zeros(
|
||||
(*image.shape[:2], 4),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
# augment image and mask
|
||||
augmentations = self.transform(image=image, mask=mask)
|
||||
image = augmentations["image"]
|
||||
mask = augmentations["mask"]
|
||||
|
||||
return image, mask
|
||||
|
||||
|
||||
class RealDataset(Dataset):
|
||||
def __init__(self, root, transforms=None):
|
||||
self.root = root
|
||||
|
@ -16,7 +48,7 @@ class RealDataset(Dataset):
|
|||
self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
|
||||
self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
|
||||
|
||||
self.res = A.SmallestMaxSize(max_size=1024)
|
||||
self.res = A.LongestMaxSize(max_size=1024)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# create paths from ids
|
||||
|
@ -97,3 +129,83 @@ class RealDataset(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
|
||||
|
||||
class LabeledDataset(Dataset):
|
||||
def __init__(self, image_dir, transforms):
|
||||
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
||||
self.transforms = transforms
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# open and convert image
|
||||
image = np.ascontiguousarray(
|
||||
Image.open(self.images[idx]).convert("RGB"),
|
||||
)
|
||||
|
||||
# open and convert mask
|
||||
mask_path = self.images[idx].parent.joinpath("MASK.PNG")
|
||||
mask = np.ascontiguousarray(
|
||||
Image.open(mask_path).convert("L"),
|
||||
)
|
||||
|
||||
# get ids from mask
|
||||
obj_ids = np.unique(mask)
|
||||
obj_ids = obj_ids[1:] # first id is the background, so remove it
|
||||
|
||||
# split the color-encoded mask into a set of binary masks
|
||||
masks = mask == obj_ids[:, None, None]
|
||||
masks = masks.astype(np.uint8) # cast to uint8 for albumentations
|
||||
|
||||
# create bboxes from masks (pascal format)
|
||||
num_objs = len(obj_ids)
|
||||
bboxes = []
|
||||
for i in range(num_objs):
|
||||
pos = np.where(masks[i])
|
||||
xmin = np.min(pos[1])
|
||||
xmax = np.max(pos[1])
|
||||
ymin = np.min(pos[0])
|
||||
ymax = np.max(pos[0])
|
||||
bboxes.append([xmin, ymin, xmax, ymax])
|
||||
|
||||
# convert arrays for albumentations
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
|
||||
labels = torch.ones((num_objs,), dtype=torch.int64) # assume there is only one class (id=1)
|
||||
masks = list(np.asarray(masks))
|
||||
|
||||
if self.transforms is not None:
|
||||
# arrange transform data
|
||||
data = {
|
||||
"image": image,
|
||||
"labels": labels,
|
||||
"bboxes": bboxes,
|
||||
"masks": masks,
|
||||
}
|
||||
# apply transform
|
||||
augmented = self.transforms(**data)
|
||||
# get augmented data
|
||||
image = augmented["image"]
|
||||
bboxes = augmented["bboxes"]
|
||||
labels = augmented["labels"]
|
||||
masks = augmented["masks"]
|
||||
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
|
||||
labels = torch.as_tensor(labels, dtype=torch.int64) # int64 required by torchvision maskrcnn
|
||||
masks = torch.stack(masks) # stack masks, required by torchvision maskrcnn
|
||||
|
||||
area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
|
||||
image_id = torch.tensor([idx])
|
||||
iscrowd = torch.zeros((num_objs,), dtype=torch.int64) # assume all instances are not crowd
|
||||
|
||||
target = {
|
||||
"boxes": bboxes,
|
||||
"labels": labels,
|
||||
"masks": masks,
|
||||
"area": area,
|
||||
"image_id": image_id,
|
||||
"iscrowd": iscrowd,
|
||||
}
|
||||
|
||||
return image, target
|
||||
|
|
Loading…
Reference in a new issue