mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
Fixed some details on predict + requirements
Former-commit-id: d04263be694bba9dc1cfeb96389a3ad2a4015841
This commit is contained in:
parent
ff85f85dcc
commit
bc465b2d91
21
predict.py
Normal file → Executable file
21
predict.py
Normal file → Executable file
|
@ -33,7 +33,7 @@ def predict_img(net,
|
|||
|
||||
X_left = torch.from_numpy(left_square).unsqueeze(0)
|
||||
X_right = torch.from_numpy(right_square).unsqueeze(0)
|
||||
|
||||
|
||||
if use_gpu:
|
||||
X_left = X_left.cuda()
|
||||
X_right = X_right.cuda()
|
||||
|
@ -52,15 +52,9 @@ def predict_img(net,
|
|||
transforms.ToTensor()
|
||||
]
|
||||
)
|
||||
|
||||
left_probs = tf(left_probs.cpu()).unsqueeze(0)
|
||||
right_probs = tf(right_probs.cpu()).unsqueeze(0)
|
||||
|
||||
|
||||
# left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
||||
# right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
||||
|
||||
|
||||
|
||||
left_probs = tf(left_probs.cpu())
|
||||
right_probs = tf(right_probs.cpu())
|
||||
|
||||
left_mask_np = left_probs.squeeze().cpu().numpy()
|
||||
right_mask_np = right_probs.squeeze().cpu().numpy()
|
||||
|
@ -80,11 +74,11 @@ def get_args():
|
|||
metavar='FILE',
|
||||
help="Specify the file in which is stored the model"
|
||||
" (default : 'MODEL.pth')")
|
||||
parser.add_argument('--input', '-i', default=['test.jpg'], metavar='INPUT', nargs='+',
|
||||
help='filenames of input images', required=False)
|
||||
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
|
||||
help='filenames of input images', required=True)
|
||||
|
||||
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
||||
help='filenames of ouput images',default=['opt.jpg'])
|
||||
help='filenames of ouput images')
|
||||
parser.add_argument('--cpu', '-c', action='store_true',
|
||||
help="Do not use the cuda version of the net",
|
||||
default=False)
|
||||
|
@ -108,7 +102,6 @@ def get_args():
|
|||
|
||||
def get_output_filenames(args):
|
||||
in_files = args.input
|
||||
# in_files = 'img.jpg'
|
||||
out_files = []
|
||||
|
||||
if not args.output:
|
||||
|
|
|
@ -2,4 +2,5 @@ matplotlib
|
|||
pydensecrf
|
||||
numpy
|
||||
Pillow
|
||||
torch
|
||||
torch
|
||||
torchvision
|
||||
|
|
Loading…
Reference in a new issue