fix scaling issue by size change from net, i.e. Issues - error in predict.py #12
Former-commit-id: 4af6e54bd4e1c642282fe512d97aa5a09ce16fca
This commit is contained in:
parent
af08dbb98d
commit
ff85f85dcc
32
predict.py
32
predict.py
|
@ -11,6 +11,8 @@ from unet import UNet
|
||||||
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
||||||
from utils import plot_img_and_mask
|
from utils import plot_img_and_mask
|
||||||
|
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
def predict_img(net,
|
def predict_img(net,
|
||||||
full_img,
|
full_img,
|
||||||
scale_factor=0.5,
|
scale_factor=0.5,
|
||||||
|
@ -40,11 +42,25 @@ def predict_img(net,
|
||||||
output_left = net(X_left)
|
output_left = net(X_left)
|
||||||
output_right = net(X_right)
|
output_right = net(X_right)
|
||||||
|
|
||||||
left_probs = F.sigmoid(output_left)
|
left_probs = F.sigmoid(output_left).squeeze(0)
|
||||||
right_probs = F.sigmoid(output_right)
|
right_probs = F.sigmoid(output_right).squeeze(0)
|
||||||
|
|
||||||
|
tf = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToPILImage(),
|
||||||
|
transforms.Resize(img_height),
|
||||||
|
transforms.ToTensor()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
left_probs = tf(left_probs.cpu()).unsqueeze(0)
|
||||||
|
right_probs = tf(right_probs.cpu()).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
# left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
||||||
|
# right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
||||||
|
|
||||||
|
|
||||||
left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
|
||||||
right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
|
||||||
|
|
||||||
left_mask_np = left_probs.squeeze().cpu().numpy()
|
left_mask_np = left_probs.squeeze().cpu().numpy()
|
||||||
right_mask_np = right_probs.squeeze().cpu().numpy()
|
right_mask_np = right_probs.squeeze().cpu().numpy()
|
||||||
|
@ -64,10 +80,11 @@ def get_args():
|
||||||
metavar='FILE',
|
metavar='FILE',
|
||||||
help="Specify the file in which is stored the model"
|
help="Specify the file in which is stored the model"
|
||||||
" (default : 'MODEL.pth')")
|
" (default : 'MODEL.pth')")
|
||||||
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
|
parser.add_argument('--input', '-i', default=['test.jpg'], metavar='INPUT', nargs='+',
|
||||||
help='filenames of input images', required=True)
|
help='filenames of input images', required=False)
|
||||||
|
|
||||||
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
||||||
help='filenames of ouput images')
|
help='filenames of ouput images',default=['opt.jpg'])
|
||||||
parser.add_argument('--cpu', '-c', action='store_true',
|
parser.add_argument('--cpu', '-c', action='store_true',
|
||||||
help="Do not use the cuda version of the net",
|
help="Do not use the cuda version of the net",
|
||||||
default=False)
|
default=False)
|
||||||
|
@ -91,6 +108,7 @@ def get_args():
|
||||||
|
|
||||||
def get_output_filenames(args):
|
def get_output_filenames(args):
|
||||||
in_files = args.input
|
in_files = args.input
|
||||||
|
# in_files = 'img.jpg'
|
||||||
out_files = []
|
out_files = []
|
||||||
|
|
||||||
if not args.output:
|
if not args.output:
|
||||||
|
|
Loading…
Reference in a new issue