Use classmethod for preprocess
Former-commit-id: ebe3bbb3ca67502db407eaa8273071a871d4b744
This commit is contained in:
parent
0d1cc25ae2
commit
67c4dd34a0
|
@ -22,8 +22,7 @@ def predict_img(net,
|
|||
use_dense_crf=False):
|
||||
net.eval()
|
||||
|
||||
ds = BasicDataset('', '', scale=scale_factor)
|
||||
img = torch.from_numpy(ds.preprocess(full_img))
|
||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
|
||||
|
||||
img = img.unsqueeze(0)
|
||||
img = img.to(device=device, dtype=torch.float32)
|
||||
|
|
|
@ -22,9 +22,10 @@ class BasicDataset(Dataset):
|
|||
def __len__(self):
|
||||
return len(self.ids)
|
||||
|
||||
def preprocess(self, pil_img):
|
||||
@classmethod
|
||||
def preprocess(cls, pil_img, scale):
|
||||
w, h = pil_img.size
|
||||
newW, newH = int(self.scale * w), int(self.scale * h)
|
||||
newW, newH = int(scale * w), int(scale * h)
|
||||
assert newW > 0 and newH > 0, 'Scale is too small'
|
||||
pil_img = pil_img.resize((newW, newH))
|
||||
|
||||
|
@ -55,7 +56,7 @@ class BasicDataset(Dataset):
|
|||
assert img.size == mask.size, \
|
||||
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
|
||||
|
||||
img = self.preprocess(img)
|
||||
mask = self.preprocess(mask)
|
||||
img = self.preprocess(img, self.scale)
|
||||
mask = self.preprocess(mask, self.scale)
|
||||
|
||||
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
|
||||
|
|
Loading…
Reference in a new issue