diff --git a/src/data/dataset.py b/src/data/dataset.py index 4b8fff3..595e00c 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -1,6 +1,6 @@ import os -from pathlib import Path +import albumentations as A import numpy as np import torch from PIL import Image @@ -16,6 +16,8 @@ class RealDataset(Dataset): self.imgs = list(sorted(os.listdir(os.path.join(root, "images")))) self.masks = list(sorted(os.listdir(os.path.join(root, "masks")))) + self.res = A.SmallestMaxSize(max_size=1024) + def __getitem__(self, idx): # create paths from ids image_path = os.path.join(self.root, "images", self.imgs[idx]) @@ -23,11 +25,16 @@ class RealDataset(Dataset): # load image and mask image = Image.open(image_path).convert("RGB") - mask = Image.open(mask_path) + mask = Image.open(mask_path).convert("L") # convert to numpy arrays - image = np.array(image) - mask = np.array(mask) + image = np.asarray(image) + mask = np.asarray(mask) + + # resize images + aug = self.res(image=image, mask=mask) + image = aug["image"] + mask = aug["mask"] # get ids from mask obj_ids = np.unique(mask)