From ff85f85dcca877a41340903ef6bd884b58af62e0 Mon Sep 17 00:00:00 2001 From: ZiyuanTonyZhang Date: Sun, 17 Jun 2018 02:31:42 -0400 Subject: [PATCH] fix scaling issue by size change from net, i.e. Issues - error in predict.py #12 Former-commit-id: 4af6e54bd4e1c642282fe512d97aa5a09ce16fca --- predict.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/predict.py b/predict.py index d963b04..d3ef938 100644 --- a/predict.py +++ b/predict.py @@ -11,6 +11,8 @@ from unet import UNet from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf from utils import plot_img_and_mask +from torchvision import transforms + def predict_img(net, full_img, scale_factor=0.5, @@ -40,11 +42,25 @@ def predict_img(net, output_left = net(X_left) output_right = net(X_right) - left_probs = F.sigmoid(output_left) - right_probs = F.sigmoid(output_right) + left_probs = F.sigmoid(output_left).squeeze(0) + right_probs = F.sigmoid(output_right).squeeze(0) + + tf = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Resize(img_height), + transforms.ToTensor() + ] + ) + + left_probs = tf(left_probs.cpu()).unsqueeze(0) + right_probs = tf(right_probs.cpu()).unsqueeze(0) + + + # left_probs = F.upsample(left_probs, size=(img_height, img_height)) + # right_probs = F.upsample(right_probs, size=(img_height, img_height)) + - left_probs = F.upsample(left_probs, size=(img_height, img_height)) - right_probs = F.upsample(right_probs, size=(img_height, img_height)) left_mask_np = left_probs.squeeze().cpu().numpy() right_mask_np = right_probs.squeeze().cpu().numpy() @@ -64,10 +80,11 @@ def get_args(): metavar='FILE', help="Specify the file in which is stored the model" " (default : 'MODEL.pth')") - parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', - help='filenames of input images', required=True) + parser.add_argument('--input', '-i', default=['test.jpg'], metavar='INPUT', nargs='+', + help='filenames of input images', required=False) + parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', - help='filenames of ouput images') + help='filenames of ouput images',default=['opt.jpg']) parser.add_argument('--cpu', '-c', action='store_true', help="Do not use the cuda version of the net", default=False) @@ -91,6 +108,7 @@ def get_args(): def get_output_filenames(args): in_files = args.input + # in_files = 'img.jpg' out_files = [] if not args.output: