From 67c4dd34a04911c8c6073514d8f1055782746607 Mon Sep 17 00:00:00 2001 From: milesial Date: Wed, 11 Dec 2019 21:57:45 +0100 Subject: [PATCH] Use classmethod for preprocess Former-commit-id: ebe3bbb3ca67502db407eaa8273071a871d4b744 --- predict.py | 3 +-- utils/dataset.py | 9 +++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/predict.py b/predict.py index cb3c1c2..95112b0 100755 --- a/predict.py +++ b/predict.py @@ -22,8 +22,7 @@ def predict_img(net, use_dense_crf=False): net.eval() - ds = BasicDataset('', '', scale=scale_factor) - img = torch.from_numpy(ds.preprocess(full_img)) + img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor)) img = img.unsqueeze(0) img = img.to(device=device, dtype=torch.float32) diff --git a/utils/dataset.py b/utils/dataset.py index 57c069a..c290eeb 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -22,9 +22,10 @@ class BasicDataset(Dataset): def __len__(self): return len(self.ids) - def preprocess(self, pil_img): + @classmethod + def preprocess(cls, pil_img, scale): w, h = pil_img.size - newW, newH = int(self.scale * w), int(self.scale * h) + newW, newH = int(scale * w), int(scale * h) assert newW > 0 and newH > 0, 'Scale is too small' pil_img = pil_img.resize((newW, newH)) @@ -55,7 +56,7 @@ class BasicDataset(Dataset): assert img.size == mask.size, \ f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}' - img = self.preprocess(img) - mask = self.preprocess(mask) + img = self.preprocess(img, self.scale) + mask = self.preprocess(mask, self.scale) return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}