Merge pull request #13 from ZiyuanTonyZhang/master
fix scaling issue Former-commit-id: b282e761f04904697418c9e759d761004a045428
This commit is contained in:
commit
14e0ff1c62
19
predict.py
Normal file → Executable file
19
predict.py
Normal file → Executable file
|
@ -11,6 +11,8 @@ from unet import UNet
|
||||||
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
||||||
from utils import plot_img_and_mask
|
from utils import plot_img_and_mask
|
||||||
|
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
def predict_img(net,
|
def predict_img(net,
|
||||||
full_img,
|
full_img,
|
||||||
scale_factor=0.5,
|
scale_factor=0.5,
|
||||||
|
@ -40,11 +42,19 @@ def predict_img(net,
|
||||||
output_left = net(X_left)
|
output_left = net(X_left)
|
||||||
output_right = net(X_right)
|
output_right = net(X_right)
|
||||||
|
|
||||||
left_probs = F.sigmoid(output_left)
|
left_probs = F.sigmoid(output_left).squeeze(0)
|
||||||
right_probs = F.sigmoid(output_right)
|
right_probs = F.sigmoid(output_right).squeeze(0)
|
||||||
|
|
||||||
left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
tf = transforms.Compose(
|
||||||
right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
[
|
||||||
|
transforms.ToPILImage(),
|
||||||
|
transforms.Resize(img_height),
|
||||||
|
transforms.ToTensor()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
left_probs = tf(left_probs.cpu())
|
||||||
|
right_probs = tf(right_probs.cpu())
|
||||||
|
|
||||||
left_mask_np = left_probs.squeeze().cpu().numpy()
|
left_mask_np = left_probs.squeeze().cpu().numpy()
|
||||||
right_mask_np = right_probs.squeeze().cpu().numpy()
|
right_mask_np = right_probs.squeeze().cpu().numpy()
|
||||||
|
@ -66,6 +76,7 @@ def get_args():
|
||||||
" (default : 'MODEL.pth')")
|
" (default : 'MODEL.pth')")
|
||||||
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
|
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
|
||||||
help='filenames of input images', required=True)
|
help='filenames of input images', required=True)
|
||||||
|
|
||||||
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
||||||
help='filenames of ouput images')
|
help='filenames of ouput images')
|
||||||
parser.add_argument('--cpu', '-c', action='store_true',
|
parser.add_argument('--cpu', '-c', action='store_true',
|
||||||
|
|
|
@ -3,3 +3,4 @@ pydensecrf
|
||||||
numpy
|
numpy
|
||||||
Pillow
|
Pillow
|
||||||
torch
|
torch
|
||||||
|
torchvision
|
||||||
|
|
Loading…
Reference in a new issue