mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-14 09:08:18 +00:00
012fca4715
Former-commit-id: de7507ff08510b48e6a0e11da849e0d1c94d3ac8
63 lines
2 KiB
Python
63 lines
2 KiB
Python
from os.path import splitext
|
|
from os import listdir
|
|
import numpy as np
|
|
from glob import glob
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import logging
|
|
from PIL import Image
|
|
|
|
|
|
class BasicDataset(Dataset):
    """Image/mask segmentation dataset backed by two directories.

    Every non-hidden entry of ``imgs_dir`` defines one sample ID (its file
    name without the extension). For each ID exactly one image file in
    ``imgs_dir`` and exactly one mask file in ``masks_dir`` must be found;
    both are rescaled by ``scale`` and returned as channel-first tensors.
    """

    def __init__(self, imgs_dir, masks_dir, scale=1):
        """Index the samples found in *imgs_dir*.

        Args:
            imgs_dir: directory holding the input images; each non-hidden
                entry contributes one sample ID.
            masks_dir: directory holding the masks, looked up by sample ID.
            scale: resize factor applied to both image and mask, in (0, 1].

        Raises:
            AssertionError: if *scale* is outside (0, 1].
        """
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        # Sample IDs are file names without their extension; hidden entries
        # (names starting with '.') are skipped.
        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        """Return the number of sample IDs."""
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale, is_mask=False):
        """Rescale a PIL image and convert it to a CHW numpy array.

        Args:
            pil_img: image to convert.
            scale: resize factor; the new size is ``int(scale * w)`` by
                ``int(scale * h)``.
            is_mask: when true, resample with nearest-neighbour so that label
                values are preserved. The default keeps the original
                behaviour (Pillow's default resize filter) for ordinary
                images and for existing callers.

        Returns:
            Array of shape (channels, height, width). A 2-D input gets a
            single channel. If the maximum value exceeds 1, the array is
            divided by 255.

        Raises:
            AssertionError: if the scaled width or height would be 0.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small'

        if is_mask:
            # Masks hold discrete labels. An interpolating filter (Pillow's
            # default is BICUBIC since 7.0) would blend neighbouring labels
            # into values that are no valid class.
            pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST)
        else:
            pil_img = pil_img.resize((newW, newH))

        img_nd = np.array(pil_img)
        if len(img_nd.shape) == 2:
            # Single-channel image: add a trailing channel axis.
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        # Heuristic: values above 1 are taken to be on a 0-255 scale.
        if img_trans.max() > 1:
            img_trans = img_trans / 255
        return img_trans

    @staticmethod
    def _find_files(directory, idx):
        """Return the paths in *directory* that belong to sample ID *idx*.

        Candidates are the entries whose name starts with *idx*. If any
        candidate's stem (name without extension) equals *idx* exactly, only
        those are returned. Otherwise all prefix matches are returned, which
        keeps naming schemes such as ``<id>_mask.gif`` working.
        """
        from glob import escape
        from os.path import basename, join

        # join() makes a trailing separator on *directory* optional, and
        # escape() makes glob metacharacters in the path or ID literal.
        candidates = glob(escape(join(directory, idx)) + '*')
        # A bare prefix match lets ID '1' also pick up '10.png'.
        exact = [f for f in candidates if splitext(basename(f))[0] == idx]
        return exact or candidates

    def __getitem__(self, i):
        """Load sample *i* and return ``{'image': tensor, 'mask': tensor}``.

        Both tensors come from :meth:`preprocess` and are laid out as
        (channels, height, width).

        Raises:
            AssertionError: if the ID does not resolve to exactly one image
                and exactly one mask, or if the two differ in size.
        """
        idx = self.ids[i]
        mask_file = self._find_files(self.masks_dir, idx)
        img_file = self._find_files(self.imgs_dir, idx)

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'

        # The context managers close the underlying files once the pixel
        # data has been copied out, instead of leaving them to the GC.
        with Image.open(mask_file[0]) as mask, Image.open(img_file[0]) as img:
            assert img.size == mask.size, \
                f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
            img_arr = self.preprocess(img, self.scale)
            mask_arr = self.preprocess(mask, self.scale, is_mask=True)

        return {'image': torch.from_numpy(img_arr), 'mask': torch.from_numpy(mask_arr)}
|