2021-08-16 00:53:00 +00:00
|
|
|
import logging
|
|
|
|
from os import listdir
|
|
|
|
from os.path import splitext
|
|
|
|
from pathlib import Path
|
|
|
|
|
2022-06-28 07:36:21 +00:00
|
|
|
import albumentations as A
|
2021-08-16 00:53:00 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from PIL import Image
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
|
|
|
2022-06-27 14:40:04 +00:00
|
|
|
class SphereDataset(Dataset):
    """Image/mask dataset backed by an image directory and an optional mask directory.

    Every non-hidden entry of ``images_dir`` defines one sample whose ID is the
    entry name without its extension. When ``masks_dir`` is given, the matching
    mask is the file in ``masks_dir`` whose name without extension is
    ``<ID><mask_suffix>``.

    Args:
        images_dir: Directory containing the input images.
        transform: Albumentations pipeline applied to every sample after
            resizing. It is called as ``transform(image=..., mask=...)`` on
            numpy arrays (``mask`` is omitted when there is no mask). ``None``
            disables augmentation.
        masks_dir: Optional directory containing the masks. Without it,
            ``__getitem__`` returns only the image.
        scale: Resize factor applied to both image and mask.
        mask_suffix: Text inserted between the sample ID and the mask extension.

    Raises:
        ValueError: If ``scale`` is not positive.
        RuntimeError: If ``images_dir`` contains no non-hidden entries.
    """

    def __init__(
        self,
        images_dir: str,
        transform: A.Compose,
        masks_dir: str = None,
        scale: float = 1.0,
        mask_suffix: str = "",
    ):
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale}")

        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir) if masks_dir else None
        self.transform = transform
        # Both attributes are read by __getitem__; previously they were never set.
        self.scale = scale
        self.mask_suffix = mask_suffix

        # Hidden entries (e.g. ".DS_Store") are skipped. Sorting makes the
        # index -> sample mapping reproducible, since listdir order is arbitrary.
        self.ids = sorted(splitext(file)[0] for file in listdir(images_dir) if not file.startswith("."))
        if not self.ids:
            raise RuntimeError(f"No input file found in {images_dir}, make sure you put your images there")

        logging.info("Creating dataset with %d examples", len(self.ids))

    def __len__(self):
        """Return the number of samples (one per image ID)."""
        return len(self.ids)

    @staticmethod
    def preprocess(pil_img, scale, is_mask):
        """Resize ``pil_img`` by ``scale`` and convert it to a numpy array.

        Images become CHW arrays divided by 255; a leading channel axis is
        added to single-channel images. Masks are returned as resized, without
        transposing or rescaling their values.

        Raises:
            AssertionError: If ``scale`` would shrink a side to zero pixels.
        """
        resized = SphereDataset._resize(pil_img, scale, is_mask)
        return SphereDataset._to_array(np.asarray(resized), is_mask)

    @staticmethod
    def _resize(pil_img, scale, is_mask):
        """Return ``pil_img`` resized by ``scale`` (nearest-neighbour for masks, bicubic otherwise)."""
        w, h = pil_img.size
        new_w, new_h = int(scale * w), int(scale * h)
        assert new_w > 0 and new_h > 0, "Scale is too small, resized images would have no pixel"
        # Nearest-neighbour keeps mask pixels equal to existing label values;
        # bicubic would blend neighbouring labels into intermediate values.
        return pil_img.resize((new_w, new_h), resample=Image.NEAREST if is_mask else Image.BICUBIC)

    @staticmethod
    def _to_array(arr, is_mask):
        """Convert an HW/HWC array to the model layout: CHW divided by 255 for images, unchanged for masks."""
        if is_mask:
            return arr
        arr = arr[np.newaxis, ...] if arr.ndim == 2 else arr.transpose((2, 0, 1))
        return arr / 255

    @staticmethod
    def load(filename):
        """Load ``filename`` as a PIL image.

        ``.npy``/``.npz`` files go through ``np.load``, ``.pt``/``.pth`` through
        ``torch.load``, and everything else through ``Image.open``. The
        extension is matched case-insensitively.

        Raises:
            ValueError: If a ``.npz`` archive does not hold exactly one array.
        """
        ext = splitext(filename)[1].lower()
        if ext == ".npy":
            return Image.fromarray(np.load(filename))
        if ext == ".npz":
            # For .npz, np.load returns an NpzFile archive, not an array, and
            # Image.fromarray cannot consume it. Require exactly one stored array.
            with np.load(filename) as archive:
                if len(archive.files) != 1:
                    raise ValueError(f"Expected exactly one array in {filename}, found {archive.files}")
                return Image.fromarray(archive[archive.files[0]])
        if ext in (".pt", ".pth"):
            # NOTE(review): torch.load unpickles the file; only load trusted data.
            return Image.fromarray(torch.load(filename).numpy())
        return Image.open(filename)

    @staticmethod
    def _find_files(directory, stem):
        """Return the files in ``directory`` whose name without extension equals ``stem``."""
        # Exact comparison instead of globbing ``stem + ".*"``. The glob would
        # also match "a.b.png" for stem "a", and it treats "[", "*", "?" in
        # the stem as wildcards.
        return [p for p in directory.iterdir() if splitext(p.name)[0] == stem and p.is_file()]

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment sample ``idx``.

        Returns:
            A dict with ``"image"`` (float tensor in CHW layout, values divided
            by 255) and, when ``masks_dir`` was given, ``"mask"`` (int64 tensor).

        Raises:
            AssertionError: If the image/mask file is missing or ambiguous, or
                if image and mask sizes differ.
        """
        name = self.ids[idx]

        img_file = self._find_files(self.images_dir, name)
        assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
        img = self.load(img_file[0])

        mask = None
        if self.masks_dir is not None:
            mask_file = self._find_files(self.masks_dir, name + self.mask_suffix)
            assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
            mask = self.load(mask_file[0])
            assert (
                img.size == mask.size
            ), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"

        img_arr = np.asarray(self._resize(img, self.scale, is_mask=False))
        mask_arr = None if mask is None else np.asarray(self._resize(mask, self.scale, is_mask=True))

        if self.transform is not None:
            # NOTE(review): the image is divided by 255 after this call, so the
            # pipeline presumably must return numpy arrays in the original value
            # range (no A.Normalize / ToTensorV2). Confirm against training setup.
            if mask_arr is None:
                img_arr = self.transform(image=img_arr)["image"]
            else:
                augmented = self.transform(image=img_arr, mask=mask_arr)
                img_arr, mask_arr = augmented["image"], augmented["mask"]

        # .copy() gives torch a writable, owned buffer (np.asarray on a PIL
        # image may return a read-only view).
        sample = {"image": torch.as_tensor(self._to_array(img_arr, is_mask=False).copy()).float().contiguous()}
        if mask_arr is not None:
            sample["mask"] = torch.as_tensor(self._to_array(mask_arr, is_mask=True).copy()).long().contiguous()
        return sample
|