Mirror of https://github.com/Laurent2916/REVA-QCAV.git (synced 2024-11-09 23:12:05 +00:00)
feat: create new datasets for generation and loading of pre-rendered images
Former-commit-id: 4ccd97c3583c7def1e5e988a5beb8fcda7545fc9 [formerly 6a7e1eb28ade30f3e4b76337c209ac7e5b3b1cbf] Former-commit-id: 8a552d696c6e47378a36a8680e3a09d151de25f1
parent: c6c08ac98a
commit: 0693f02d83
@@ -2,9 +2,9 @@ import albumentations as A
 import pytorch_lightning as pl
 import wandb
 from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
 
-from .dataset import RealDataset
+from .dataset import LabeledDataset, RealDataset
 
 
 def collate_fn(batch):
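The body of collate_fn is unchanged and not shown in this diff. For context, a minimal sketch of the collate function detection pipelines typically use (an assumption, not this repository's confirmed code): images and targets are kept as parallel tuples because each sample carries a different number of objects.

def collate_fn(batch):
    # group samples into (images, targets) tuples instead of stacking tensors
    return tuple(zip(*batch))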
@@ -18,13 +18,13 @@ class Spheres(pl.LightningDataModule):
     def train_dataloader(self):
         transforms = A.Compose(
             [
-                A.Flip(),
-                A.ColorJitter(),
-                A.ToGray(p=0.01),
-                A.GaussianBlur(),
-                A.MotionBlur(),
-                A.ISONoise(),
-                A.ImageCompression(),
+                # A.Flip(),
+                # A.ColorJitter(),
+                # A.ToGray(p=0.01),
+                # A.GaussianBlur(),
+                # A.MotionBlur(),
+                # A.ISONoise(),
+                # A.ImageCompression(),
                 A.Normalize(
                     mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225],
@@ -40,7 +40,9 @@ class Spheres(pl.LightningDataModule):
                 ),
             ),
         )
 
-        dataset = RealDataset(root="/dev/shm/TRAIN/", transforms=transforms)
+        dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
+        # dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
+        # dataset = Subset(dataset, [0] * 320) # overfit test
 
         return DataLoader(
             dataset,
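LabeledDataset (introduced below) routes bboxes, labels, and masks through this transform, and albumentations only accepts those keys when the Compose declares bbox parameters. The middle of the pipeline is elided from this diff, so the following is a minimal sketch of the required wiring under that assumption, not the repository's confirmed code:

transforms = A.Compose(
    [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ],
    # "pascal_voc" matches the "(pascal format)" boxes LabeledDataset builds;
    # label_fields keeps labels aligned with boxes that survive augmentation
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)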
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 
 import albumentations as A
 import numpy as np
@@ -7,6 +8,37 @@ from PIL import Image
 from torch.utils.data import Dataset
 
 
+class SyntheticDataset(Dataset):
+    def __init__(self, image_dir, transform):
+        self.images = list(Path(image_dir).glob("**/*.jpg"))
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        # open and convert image
+        image = np.ascontiguousarray(
+            Image.open(
+                self.images[index],
+            ).convert("RGB"),
+            dtype=np.uint8,
+        )
+
+        # create empty mask of same size
+        mask = np.zeros(
+            (*image.shape[:2], 4),
+            dtype=np.uint8,
+        )
+
+        # augment image and mask
+        augmentations = self.transform(image=image, mask=mask)
+        image = augmentations["image"]
+        mask = augmentations["mask"]
+
+        return image, mask
+
+
 class RealDataset(Dataset):
     def __init__(self, root, transforms=None):
         self.root = root
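SyntheticDataset pairs each rendered image with an all-zero 4-channel mask so the generated data can flow through the same image-plus-mask augmentation path as labeled data. A quick usage sketch; the directory path is hypothetical, not from this commit:

import albumentations as A

dataset = SyntheticDataset(
    image_dir="/dev/shm/RENDER/",  # hypothetical path
    transform=A.Compose([A.LongestMaxSize(max_size=1024)]),
)
image, mask = dataset[0]
# image: HxWx3 uint8 array; mask: HxWx4 zeros, resized together with the image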
@@ -16,7 +48,7 @@ class RealDataset(Dataset):
         self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
         self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
 
-        self.res = A.SmallestMaxSize(max_size=1024)
+        self.res = A.LongestMaxSize(max_size=1024)
 
     def __getitem__(self, idx):
         # create paths from ids
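The resize swap changes which side is bounded: SmallestMaxSize scales until the shorter side reaches 1024, letting the longer side exceed it, while LongestMaxSize caps the longer side at 1024 so the whole image fits inside a 1024x1024 box. A quick check with a made-up 768x1536 image:

import albumentations as A
import numpy as np

img = np.zeros((768, 1536, 3), dtype=np.uint8)
print(A.SmallestMaxSize(max_size=1024)(image=img)["image"].shape)  # (1024, 2048, 3)
print(A.LongestMaxSize(max_size=1024)(image=img)["image"].shape)   # (512, 1024, 3)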
@@ -97,3 +129,83 @@ class RealDataset(Dataset):
 
     def __len__(self):
         return len(self.imgs)
+
+
+class LabeledDataset(Dataset):
+    def __init__(self, image_dir, transforms):
+        self.images = list(Path(image_dir).glob("**/*.jpg"))
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        # open and convert image
+        image = np.ascontiguousarray(
+            Image.open(self.images[idx]).convert("RGB"),
+        )
+
+        # open and convert mask
+        mask_path = self.images[idx].parent.joinpath("MASK.PNG")
+        mask = np.ascontiguousarray(
+            Image.open(mask_path).convert("L"),
+        )
+
+        # get ids from mask
+        obj_ids = np.unique(mask)
+        obj_ids = obj_ids[1:] # first id is the background, so remove it
+
+        # split the color-encoded mask into a set of binary masks
+        masks = mask == obj_ids[:, None, None]
+        masks = masks.astype(np.uint8) # cast to uint8 for albumentations
+
+        # create bboxes from masks (pascal format)
+        num_objs = len(obj_ids)
+        bboxes = []
+        for i in range(num_objs):
+            pos = np.where(masks[i])
+            xmin = np.min(pos[1])
+            xmax = np.max(pos[1])
+            ymin = np.min(pos[0])
+            ymax = np.max(pos[0])
+            bboxes.append([xmin, ymin, xmax, ymax])
+
+        # convert arrays for albumentations
+        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
+        labels = torch.ones((num_objs,), dtype=torch.int64) # assume there is only one class (id=1)
+        masks = list(np.asarray(masks))
+
+        if self.transforms is not None:
+            # arrange transform data
+            data = {
+                "image": image,
+                "labels": labels,
+                "bboxes": bboxes,
+                "masks": masks,
+            }
+            # apply transform
+            augmented = self.transforms(**data)
+            # get augmented data
+            image = augmented["image"]
+            bboxes = augmented["bboxes"]
+            labels = augmented["labels"]
+            masks = augmented["masks"]
+
+        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
+        labels = torch.as_tensor(labels, dtype=torch.int64) # int64 required by torchvision maskrcnn
+        masks = torch.stack(masks) # stack masks, required by torchvision maskrcnn
+
+        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        image_id = torch.tensor([idx])
+        iscrowd = torch.zeros((num_objs,), dtype=torch.int64) # assume all instances are not crowd
+
+        target = {
+            "boxes": bboxes,
+            "labels": labels,
+            "masks": masks,
+            "area": area,
+            "image_id": image_id,
+            "iscrowd": iscrowd,
+        }
+
+        return image, target
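LabeledDataset returns (image, target) pairs in the torchvision detection format. A minimal end-to-end sketch of how they feed a Mask R-CNN training step, assuming the transform ends with ToTensorV2 so images arrive as CxHxW tensors; the batch size and model choice are illustrative, not confirmed by this commit:

import torchvision
from torch.utils.data import DataLoader

dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

# one foreground class (id=1) plus background
model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2)
model.train()

images, targets = next(iter(loader))
loss_dict = model(list(images), list(targets))  # dict of box/mask/classification losses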