Simplifying

- avoid using classmethod if the class argument is not used, use staticmethod instead
- avoid testing for is_mask three times if one is enough

Former-commit-id: 1116e3096cd5210a0596d1f3938989016e0f3b6c
This commit is contained in:
Gouvernathor 2022-03-31 16:23:37 +02:00 committed by GitHub
parent 2ca43802cc
commit 7f800aff75

View file

@ -25,26 +25,26 @@ class BasicDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.ids) return len(self.ids)
@classmethod @staticmethod
def preprocess(cls, pil_img, scale, is_mask): def preprocess(pil_img, scale, is_mask):
w, h = pil_img.size w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h) newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
img_ndarray = np.asarray(pil_img) img_ndarray = np.asarray(pil_img)
if img_ndarray.ndim == 2 and not is_mask:
img_ndarray = img_ndarray[np.newaxis, ...]
elif not is_mask:
img_ndarray = img_ndarray.transpose((2, 0, 1))
if not is_mask: if not is_mask:
if img_ndarray.ndim == 2:
img_ndarray = img_ndarray[np.newaxis, ...]
else:
img_ndarray = img_ndarray.transpose((2, 0, 1))
img_ndarray = img_ndarray / 255 img_ndarray = img_ndarray / 255
return img_ndarray return img_ndarray
@classmethod @staticmethod
def load(cls, filename): def load(filename):
ext = splitext(filename)[1] ext = splitext(filename)[1]
if ext in ['.npz', '.npy']: if ext in ['.npz', '.npy']:
return Image.fromarray(np.load(filename)) return Image.fromarray(np.load(filename))