Fixed some details on predict + requirements

Former-commit-id: d04263be694bba9dc1cfeb96389a3ad2a4015841
This commit is contained in:
milesial 2018-06-23 19:57:53 +02:00
parent ff85f85dcc
commit bc465b2d91
2 changed files with 9 additions and 15 deletions

21
predict.py Normal file → Executable file
View file

@ -33,7 +33,7 @@ def predict_img(net,
X_left = torch.from_numpy(left_square).unsqueeze(0) X_left = torch.from_numpy(left_square).unsqueeze(0)
X_right = torch.from_numpy(right_square).unsqueeze(0) X_right = torch.from_numpy(right_square).unsqueeze(0)
if use_gpu: if use_gpu:
X_left = X_left.cuda() X_left = X_left.cuda()
X_right = X_right.cuda() X_right = X_right.cuda()
@ -52,15 +52,9 @@ def predict_img(net,
transforms.ToTensor() transforms.ToTensor()
] ]
) )
left_probs = tf(left_probs.cpu()).unsqueeze(0) left_probs = tf(left_probs.cpu())
right_probs = tf(right_probs.cpu()).unsqueeze(0) right_probs = tf(right_probs.cpu())
# left_probs = F.upsample(left_probs, size=(img_height, img_height))
# right_probs = F.upsample(right_probs, size=(img_height, img_height))
left_mask_np = left_probs.squeeze().cpu().numpy() left_mask_np = left_probs.squeeze().cpu().numpy()
right_mask_np = right_probs.squeeze().cpu().numpy() right_mask_np = right_probs.squeeze().cpu().numpy()
@ -80,11 +74,11 @@ def get_args():
metavar='FILE', metavar='FILE',
help="Specify the file in which is stored the model" help="Specify the file in which is stored the model"
" (default : 'MODEL.pth')") " (default : 'MODEL.pth')")
parser.add_argument('--input', '-i', default=['test.jpg'], metavar='INPUT', nargs='+', parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
help='filenames of input images', required=False) help='filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
help='filenames of ouput images',default=['opt.jpg']) help='filenames of ouput images')
parser.add_argument('--cpu', '-c', action='store_true', parser.add_argument('--cpu', '-c', action='store_true',
help="Do not use the cuda version of the net", help="Do not use the cuda version of the net",
default=False) default=False)
@ -108,7 +102,6 @@ def get_args():
def get_output_filenames(args): def get_output_filenames(args):
in_files = args.input in_files = args.input
# in_files = 'img.jpg'
out_files = [] out_files = []
if not args.output: if not args.output:

View file

@ -2,4 +2,5 @@ matplotlib
pydensecrf pydensecrf
numpy numpy
Pillow Pillow
torch torch
torchvision