mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Simplifying
- avoid using classmethod if the class argument is not used, use staticmethod instead - avoid testing for is_mask three times if one is enough Former-commit-id: 1116e3096cd5210a0596d1f3938989016e0f3b6c
This commit is contained in:
parent
2ca43802cc
commit
7f800aff75
|
@ -25,26 +25,26 @@ class BasicDataset(Dataset):
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.ids)
|
return len(self.ids)
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def preprocess(cls, pil_img, scale, is_mask):
|
def preprocess(pil_img, scale, is_mask):
|
||||||
w, h = pil_img.size
|
w, h = pil_img.size
|
||||||
newW, newH = int(scale * w), int(scale * h)
|
newW, newH = int(scale * w), int(scale * h)
|
||||||
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
|
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
|
||||||
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
|
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
|
||||||
img_ndarray = np.asarray(pil_img)
|
img_ndarray = np.asarray(pil_img)
|
||||||
|
|
||||||
if img_ndarray.ndim == 2 and not is_mask:
|
|
||||||
img_ndarray = img_ndarray[np.newaxis, ...]
|
|
||||||
elif not is_mask:
|
|
||||||
img_ndarray = img_ndarray.transpose((2, 0, 1))
|
|
||||||
|
|
||||||
if not is_mask:
|
if not is_mask:
|
||||||
|
if img_ndarray.ndim == 2:
|
||||||
|
img_ndarray = img_ndarray[np.newaxis, ...]
|
||||||
|
else:
|
||||||
|
img_ndarray = img_ndarray.transpose((2, 0, 1))
|
||||||
|
|
||||||
img_ndarray = img_ndarray / 255
|
img_ndarray = img_ndarray / 255
|
||||||
|
|
||||||
return img_ndarray
|
return img_ndarray
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def load(cls, filename):
|
def load(filename):
|
||||||
ext = splitext(filename)[1]
|
ext = splitext(filename)[1]
|
||||||
if ext in ['.npz', '.npy']:
|
if ext in ['.npz', '.npy']:
|
||||||
return Image.fromarray(np.load(filename))
|
return Image.fromarray(np.load(filename))
|
||||||
|
|
Loading…
Reference in a new issue