
82 lines
2.7 KiB
Raw Normal View History

import logging
from os import listdir
from os.path import splitext
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
class SphereDataset(Dataset):
def __init__(
    self,
    images_dir: str,
    transform: A.Compose,
    masks_dir: str = None,
    scale: float = 1.0,
    mask_suffix: str = "",
):
    """Index the samples found in ``images_dir``.

    Args:
        images_dir: Directory holding the input images. Every entry whose name
            does not start with "." yields one sample ID: its name without the
            extension.
        transform: Albumentations ``Compose`` pipeline (per the annotation).
            It is stored as ``self.transform`` and not applied here.
        masks_dir: Optional directory holding the masks. It is stored as a
            ``Path``, or as ``None`` when falsy.
        scale: Resize factor stored as ``self.scale``. ``__getitem__`` passes
            it to ``preprocess``.
        mask_suffix: Text appended to a sample ID when ``__getitem__`` globs
            for the matching mask file. It is stored as ``self.mask_suffix``.

    Raises:
        RuntimeError: If ``images_dir`` contains no entry that is not a dotfile.
    """
    self.images_dir = Path(images_dir)
    self.masks_dir = Path(masks_dir) if masks_dir else None
    # Previously discarded. Keep it so the sample pipeline can use it.
    self.transform = transform
    # __getitem__ reads both attributes, so every instance must define them.
    self.scale = scale
    self.mask_suffix = mask_suffix
    # Dotfiles are skipped. Each remaining entry contributes its stem as an ID.
    self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith(".")]
    if not self.ids:
        raise RuntimeError(f"No input file found in {images_dir}, make sure you put your images there")
    logging.info("Creating dataset with %d examples", len(self.ids))
def __len__(self):
    """Return how many sample IDs were indexed from the images directory."""
    sample_ids = self.ids
    return len(sample_ids)
@staticmethod
def preprocess(pil_img, scale, is_mask):
    """Resize a PIL image by ``scale`` and convert it to a NumPy array.

    Declared static because it uses no instance state, yet callers invoke it
    as ``self.preprocess(img, scale, is_mask=...)``.

    Args:
        pil_img: PIL image to convert.
        scale: Factor applied to both width and height. The new size is
            truncated with ``int``.
        is_mask: When true, resize with nearest-neighbour sampling and return
            the array exactly as ``np.asarray`` produced it. When false, resize
            bicubically, put the channel axis first and divide by 255.

    Returns:
        For masks, the raw ``np.asarray`` result. For images, a float array
        shaped ``(1, H, W)`` when the image has a single band, otherwise the
        ``(H, W, C)`` array transposed to ``(C, H, W)``.

    Raises:
        AssertionError: If the scaled width or height is not positive.
    """
    w, h = pil_img.size
    new_w, new_h = int(scale * w), int(scale * h)
    assert new_w > 0 and new_h > 0, "Scale is too small, resized images would have no pixel"
    # Nearest-neighbour keeps mask label values discrete; bicubic is used for images.
    resample = Image.NEAREST if is_mask else Image.BICUBIC
    pil_img = pil_img.resize((new_w, new_h), resample=resample)
    img_ndarray = np.asarray(pil_img)

    if not is_mask:
        if img_ndarray.ndim == 2:
            # Single-band (H, W): add a leading channel axis -> (1, H, W).
            img_ndarray = img_ndarray[np.newaxis, ...]
        else:
            # Multi-band (H, W, C) -> channels-first (C, H, W).
            img_ndarray = img_ndarray.transpose((2, 0, 1))
        # Normalise images only; scaling a mask would destroy its integer labels.
        img_ndarray = img_ndarray / 255

    return img_ndarray
@staticmethod
def load(filename):
    """Read ``filename`` and return it as a PIL image.

    Declared static because it uses no instance state, yet callers invoke it
    as ``self.load(path)``.

    Dispatch by extension (case-insensitive):
      * ``.npy`` / ``.npz``: ``np.load``, then ``Image.fromarray``.
      * ``.pt`` / ``.pth``: ``torch.load``, then ``.numpy()`` and ``Image.fromarray``.
      * anything else: ``Image.open``.

    Args:
        filename: Path to the file, as ``str`` or ``os.PathLike``.

    Returns:
        A ``PIL.Image.Image``.
    """
    ext = splitext(filename)[1].lower()
    if ext in (".npz", ".npy"):
        # NOTE(review): for ".npz", np.load returns an NpzFile archive rather than a
        # single array, which Image.fromarray will reject. Confirm whether .npz inputs occur.
        return Image.fromarray(np.load(filename))
    if ext in (".pt", ".pth"):
        # torch.load unpickles the file, so only use it on trusted data.
        # map_location="cpu" allows tensors saved from a GPU to be converted with .numpy().
        return Image.fromarray(torch.load(filename, map_location="cpu").numpy())
    # Regular image formats (.png, .jpg, ...) go through PIL directly.
    return Image.open(filename)
def __getitem__(self, idx):
name = self.ids[idx]
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
img_file = list(self.images_dir.glob(name + ".*"))
assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
mask = self.load(mask_file[0])
img = self.load(img_file[0])
assert (
img.size == mask.size
), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
img = self.preprocess(img, self.scale, is_mask=False)
mask = self.preprocess(mask, self.scale, is_mask=True)
return {
"image": torch.as_tensor(img.copy()).float().contiguous(),
"mask": torch.as_tensor(mask.copy()).long().contiguous(),