mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Fixed some details on predict + requirements
Former-commit-id: d04263be694bba9dc1cfeb96389a3ad2a4015841
This commit is contained in:
parent
ff85f85dcc
commit
bc465b2d91
17
predict.py
Normal file → Executable file
17
predict.py
Normal file → Executable file
|
@ -53,14 +53,8 @@ def predict_img(net,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
left_probs = tf(left_probs.cpu()).unsqueeze(0)
|
left_probs = tf(left_probs.cpu())
|
||||||
right_probs = tf(right_probs.cpu()).unsqueeze(0)
|
right_probs = tf(right_probs.cpu())
|
||||||
|
|
||||||
|
|
||||||
# left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
|
||||||
# right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
left_mask_np = left_probs.squeeze().cpu().numpy()
|
left_mask_np = left_probs.squeeze().cpu().numpy()
|
||||||
right_mask_np = right_probs.squeeze().cpu().numpy()
|
right_mask_np = right_probs.squeeze().cpu().numpy()
|
||||||
|
@ -80,11 +74,11 @@ def get_args():
|
||||||
metavar='FILE',
|
metavar='FILE',
|
||||||
help="Specify the file in which is stored the model"
|
help="Specify the file in which is stored the model"
|
||||||
" (default : 'MODEL.pth')")
|
" (default : 'MODEL.pth')")
|
||||||
parser.add_argument('--input', '-i', default=['test.jpg'], metavar='INPUT', nargs='+',
|
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
|
||||||
help='filenames of input images', required=False)
|
help='filenames of input images', required=True)
|
||||||
|
|
||||||
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
||||||
help='filenames of ouput images',default=['opt.jpg'])
|
help='filenames of ouput images')
|
||||||
parser.add_argument('--cpu', '-c', action='store_true',
|
parser.add_argument('--cpu', '-c', action='store_true',
|
||||||
help="Do not use the cuda version of the net",
|
help="Do not use the cuda version of the net",
|
||||||
default=False)
|
default=False)
|
||||||
|
@ -108,7 +102,6 @@ def get_args():
|
||||||
|
|
||||||
def get_output_filenames(args):
|
def get_output_filenames(args):
|
||||||
in_files = args.input
|
in_files = args.input
|
||||||
# in_files = 'img.jpg'
|
|
||||||
out_files = []
|
out_files = []
|
||||||
|
|
||||||
if not args.output:
|
if not args.output:
|
||||||
|
|
|
@ -3,3 +3,4 @@ pydensecrf
|
||||||
numpy
|
numpy
|
||||||
Pillow
|
Pillow
|
||||||
torch
|
torch
|
||||||
|
torchvision
|
||||||
|
|
Loading…
Reference in a new issue