mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 22:42:02 +00:00
61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
|
import logging
from glob import escape, glob
from os import listdir
from os.path import basename, join, splitext

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
|
||
|
|
||
|
|
||
|
class BasicDataset(Dataset):
    """Dataset of image/mask pairs stored in two directories.

    Sample IDs are the extension-less names of the non-hidden entries of
    ``imgs_dir``. For every ID, exactly one file whose name starts with that
    ID must exist in ``imgs_dir`` and in ``masks_dir``.
    """

    def __init__(self, imgs_dir, masks_dir, scale=1):
        """Index the images found in *imgs_dir*.

        Args:
            imgs_dir: Directory containing the input images.
            masks_dir: Directory containing the matching masks.
            scale: Resize factor applied to both image and mask, in (0, 1].

        Raises:
            AssertionError: If *scale* is outside (0, 1].
        """
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        # listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible between runs and machines.
        self.ids = sorted(splitext(file)[0] for file in listdir(imgs_dir)
                          if not file.startswith('.'))
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.ids)

    def preprocess(self, pil_img):
        """Resize *pil_img* by ``self.scale`` and return it as a CHW array.

        Two-dimensional (single-channel) inputs get a leading channel axis.
        Values are divided by 255 only when the maximum exceeds 1, so arrays
        that already lie in [0, 1] (e.g. 0/1 masks) keep their values.
        Integer and boolean inputs always come back as float64, whatever
        their content, so that every sample of a dataset has the same dtype.

        Args:
            pil_img: Object exposing ``size`` and ``resize`` like a PIL image
                and convertible with ``np.array``.

        Returns:
            numpy.ndarray of shape (channels, height, width).
        """
        w, h = pil_img.size
        new_w, new_h = int(self.scale * w), int(self.scale * h)
        pil_img = pil_img.resize((new_w, new_h))

        img_nd = np.array(pil_img)
        if img_nd.ndim == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255
        elif not np.issubdtype(img_trans.dtype, np.floating):
            # Without this an all-dark image/mask would stay uint8 (or bool)
            # while every other sample is float64.
            img_trans = img_trans.astype(np.float64)

        return img_trans

    def _find_unique(self, directory, idx, kind):
        """Return the path of the single file in *directory* for ID *idx*.

        Files are matched by name prefix (so ``<idx>_mask.gif`` is accepted).
        When several files share the prefix, the one whose extension-less
        name equals *idx* exactly is preferred, which keeps ID ``1`` from
        colliding with ``10``.

        Args:
            directory: Directory to search, with or without trailing separator.
            idx: Sample ID; glob metacharacters in it are matched literally.
            kind: Noun used in the error message (``'image'`` or ``'mask'``).

        Raises:
            AssertionError: If no file, or no unambiguous file, is found.
        """
        matches = glob(escape(join(directory, idx)) + '*')
        if len(matches) > 1:
            exact = [m for m in matches if splitext(basename(m))[0] == idx]
            if len(exact) == 1:
                matches = exact
        assert len(matches) == 1, \
            f'Either no {kind} or multiple {kind}s found for the ID {idx}: {matches}'
        return matches[0]

    def __getitem__(self, i):
        """Load, validate and preprocess the *i*-th image/mask pair.

        Returns:
            dict with keys ``'image'`` and ``'mask'``, each a torch tensor
            built from the corresponding preprocessed array.

        Raises:
            AssertionError: If the image or mask cannot be located
                unambiguously, or if their sizes differ.
        """
        idx = self.ids[i]
        mask_file = self._find_unique(self.masks_dir, idx, 'mask')
        img_file = self._find_unique(self.imgs_dir, idx, 'image')

        # Context managers close the underlying files once the pixel data
        # has been copied into numpy arrays.
        with Image.open(mask_file) as mask, Image.open(img_file) as img:
            assert img.size == mask.size, \
                f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'

            img_arr = self.preprocess(img)
            mask_arr = self.preprocess(mask)

        return {'image': torch.from_numpy(img_arr), 'mask': torch.from_numpy(mask_arr)}
|