diff --git a/utils/data_loading.py b/utils/data_loading.py index 2d6e4e6..63c4313 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -25,26 +25,26 @@ class BasicDataset(Dataset): def __len__(self): return len(self.ids) - @classmethod - def preprocess(cls, pil_img, scale, is_mask): + @staticmethod + def preprocess(pil_img, scale, is_mask): w, h = pil_img.size newW, newH = int(scale * w), int(scale * h) assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) img_ndarray = np.asarray(pil_img) - if img_ndarray.ndim == 2 and not is_mask: - img_ndarray = img_ndarray[np.newaxis, ...] - elif not is_mask: - img_ndarray = img_ndarray.transpose((2, 0, 1)) - if not is_mask: + if img_ndarray.ndim == 2: + img_ndarray = img_ndarray[np.newaxis, ...] + else: + img_ndarray = img_ndarray.transpose((2, 0, 1)) + img_ndarray = img_ndarray / 255 return img_ndarray - @classmethod - def load(cls, filename): + @staticmethod + def load(filename): ext = splitext(filename)[1] if ext in ['.npz', '.npy']: return Image.fromarray(np.load(filename))