From 7f800aff75d275e2ecbf826c327f0f97d62e23c6 Mon Sep 17 00:00:00 2001 From: Gouvernathor Date: Thu, 31 Mar 2022 16:23:37 +0200 Subject: [PATCH] Simplifying - avoid using classmethod if the class argument is not used, use staticmethod instead - avoid testing for is_mask three times if one is enough Former-commit-id: 1116e3096cd5210a0596d1f3938989016e0f3b6c --- utils/data_loading.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/utils/data_loading.py b/utils/data_loading.py index 2d6e4e6..63c4313 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -25,26 +25,26 @@ class BasicDataset(Dataset): def __len__(self): return len(self.ids) - @classmethod - def preprocess(cls, pil_img, scale, is_mask): + @staticmethod + def preprocess(pil_img, scale, is_mask): w, h = pil_img.size newW, newH = int(scale * w), int(scale * h) assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) img_ndarray = np.asarray(pil_img) - if img_ndarray.ndim == 2 and not is_mask: - img_ndarray = img_ndarray[np.newaxis, ...] - elif not is_mask: - img_ndarray = img_ndarray.transpose((2, 0, 1)) - if not is_mask: + if img_ndarray.ndim == 2: + img_ndarray = img_ndarray[np.newaxis, ...] + else: + img_ndarray = img_ndarray.transpose((2, 0, 1)) + img_ndarray = img_ndarray / 255 return img_ndarray - @classmethod - def load(cls, filename): + @staticmethod + def load(filename): ext = splitext(filename)[1] if ext in ['.npz', '.npy']: return Image.fromarray(np.load(filename))