From bc465b2d91a970e4d325e470d1a700722dccf5cd Mon Sep 17 00:00:00 2001 From: milesial Date: Sat, 23 Jun 2018 19:57:53 +0200 Subject: [PATCH] Fixed some details on predict + requirements Former-commit-id: d04263be694bba9dc1cfeb96389a3ad2a4015841 --- predict.py | 21 +++++++-------------- requirements.txt | 3 ++- 2 files changed, 9 insertions(+), 15 deletions(-) mode change 100644 => 100755 predict.py diff --git a/predict.py b/predict.py old mode 100644 new mode 100755 index d3ef938..176e7a1 --- a/predict.py +++ b/predict.py @@ -33,7 +33,7 @@ def predict_img(net, X_left = torch.from_numpy(left_square).unsqueeze(0) X_right = torch.from_numpy(right_square).unsqueeze(0) - + if use_gpu: X_left = X_left.cuda() X_right = X_right.cuda() @@ -52,15 +52,9 @@ def predict_img(net, transforms.ToTensor() ] ) - - left_probs = tf(left_probs.cpu()).unsqueeze(0) - right_probs = tf(right_probs.cpu()).unsqueeze(0) - - - # left_probs = F.upsample(left_probs, size=(img_height, img_height)) - # right_probs = F.upsample(right_probs, size=(img_height, img_height)) - - + + left_probs = tf(left_probs.cpu()) + right_probs = tf(right_probs.cpu()) left_mask_np = left_probs.squeeze().cpu().numpy() right_mask_np = right_probs.squeeze().cpu().numpy() @@ -80,11 +74,11 @@ def get_args(): metavar='FILE', help="Specify the file in which is stored the model" " (default : 'MODEL.pth')") - parser.add_argument('--input', '-i', default=['test.jpg'], metavar='INPUT', nargs='+', - help='filenames of input images', required=False) + parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', + help='filenames of input images', required=True) parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', - help='filenames of ouput images',default=['opt.jpg']) + help='filenames of ouput images') parser.add_argument('--cpu', '-c', action='store_true', help="Do not use the cuda version of the net", default=False) @@ -108,7 +102,6 @@ def get_args(): def get_output_filenames(args): in_files = args.input - # in_files = 'img.jpg' out_files = [] if not args.output: diff --git a/requirements.txt b/requirements.txt index 2d8ee43..d22f05c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ matplotlib pydensecrf numpy Pillow -torch \ No newline at end of file +torch +torchvision