Use classmethod for preprocess

Former-commit-id: ebe3bbb3ca67502db407eaa8273071a871d4b744
This commit is contained in:
milesial 2019-12-11 21:57:45 +01:00
parent 0d1cc25ae2
commit 67c4dd34a0
2 changed files with 6 additions and 6 deletions

View file

@ -22,8 +22,7 @@ def predict_img(net,
use_dense_crf=False):
net.eval()
ds = BasicDataset('', '', scale=scale_factor)
img = torch.from_numpy(ds.preprocess(full_img))
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)

View file

@ -22,9 +22,10 @@ class BasicDataset(Dataset):
def __len__(self):
return len(self.ids)
def preprocess(self, pil_img):
@classmethod
def preprocess(cls, pil_img, scale):
w, h = pil_img.size
newW, newH = int(self.scale * w), int(self.scale * h)
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((newW, newH))
@ -55,7 +56,7 @@ class BasicDataset(Dataset):
assert img.size == mask.size, \
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img)
mask = self.preprocess(mask)
img = self.preprocess(img, self.scale)
mask = self.preprocess(mask, self.scale)
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}