import logging
from os import listdir
from os.path import splitext
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class SphereDataset(Dataset):
    """Segmentation dataset pairing each image with the mask that shares its file stem."""

    def __init__(
        self,
        images_dir: str,
        transform: A.Compose,
        masks_dir: str = None,
        scale: float = 1.0,
        mask_suffix: str = "",
    ):
        self.images_dir = Path(images_dir)
        # masks_dir is optional at construction time, but __getitem__ needs it to load masks.
        self.masks_dir = Path(masks_dir) if masks_dir else None
        self.transform = transform

        assert 0 < scale <= 1, "Scale must be between 0 and 1"
        self.scale = scale
        self.mask_suffix = mask_suffix

        # An example ID is the image file stem; the image and its mask are looked up by this stem.
        self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith(".")]
        if not self.ids:
            raise RuntimeError(f"No input files found in {images_dir}; make sure you put your images there")

        logging.info(f"Creating dataset with {len(self.ids)} examples")

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def preprocess(pil_img, scale, is_mask):
        # Resize (nearest neighbour for masks so label values stay intact), then convert to a
        # channel-first array; images are additionally scaled to [0, 1].
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, "Scale is too small; resized images would have no pixels"
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img_ndarray = np.asarray(pil_img)

        if not is_mask:
            if img_ndarray.ndim == 2:
                # Grayscale image: add a channel dimension -> (1, H, W).
                img_ndarray = img_ndarray[np.newaxis, ...]
            else:
                # Color image: HWC -> CHW.
                img_ndarray = img_ndarray.transpose((2, 0, 1))

            img_ndarray = img_ndarray / 255

        return img_ndarray

    @staticmethod
    def load(filename):
        # Accept NumPy arrays, serialized tensors, or regular image files.
        ext = splitext(filename)[1]
        if ext in [".npz", ".npy"]:
            return Image.fromarray(np.load(filename))
        elif ext in [".pt", ".pth"]:
            return Image.fromarray(torch.load(filename).numpy())
        else:
            return Image.open(filename)

    def __getitem__(self, idx):
        name = self.ids[idx]
        mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
        img_file = list(self.images_dir.glob(name + ".*"))

        assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
        assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
        mask = self.load(mask_file[0])
        img = self.load(img_file[0])

        assert (
            img.size == mask.size
        ), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"

        if self.transform is not None:
            # Apply the albumentations pipeline jointly to image and mask so spatial
            # augmentations stay aligned. Assumes the pipeline returns uint8 arrays.
            augmented = self.transform(image=np.asarray(img), mask=np.asarray(mask))
            img = Image.fromarray(augmented["image"])
            mask = Image.fromarray(augmented["mask"])

        img = self.preprocess(img, self.scale, is_mask=False)
        mask = self.preprocess(mask, self.scale, is_mask=True)

        return {
            "image": torch.as_tensor(img.copy()).float().contiguous(),
            "mask": torch.as_tensor(mask.copy()).long().contiguous(),
        }
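

# Usage sketch, not part of the original module: the "data/images" / "data/masks" paths
# and the augmentation pipeline below are hypothetical and only illustrate how the
# dataset is meant to be consumed with a DataLoader.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # A light albumentations pipeline; any A.Compose that returns uint8 arrays works here.
    example_transform = A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
        ]
    )

    dataset = SphereDataset("data/images", example_transform, masks_dir="data/masks")
    # Default collation assumes every image shares the same size after resizing.
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    batch = next(iter(loader))
    print(batch["image"].shape, batch["mask"].shape)  # e.g. [4, 3, H, W] and [4, H, W]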