mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
Merge pull request #13 from ZiyuanTonyZhang/master
fix scaling issue Former-commit-id: b282e761f04904697418c9e759d761004a045428
This commit is contained in:
commit
14e0ff1c62
21
predict.py
Normal file → Executable file
21
predict.py
Normal file → Executable file
|
@ -11,6 +11,8 @@ from unet import UNet
|
|||
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
||||
from utils import plot_img_and_mask
|
||||
|
||||
from torchvision import transforms
|
||||
|
||||
def predict_img(net,
|
||||
full_img,
|
||||
scale_factor=0.5,
|
||||
|
@ -31,7 +33,7 @@ def predict_img(net,
|
|||
|
||||
X_left = torch.from_numpy(left_square).unsqueeze(0)
|
||||
X_right = torch.from_numpy(right_square).unsqueeze(0)
|
||||
|
||||
|
||||
if use_gpu:
|
||||
X_left = X_left.cuda()
|
||||
X_right = X_right.cuda()
|
||||
|
@ -40,11 +42,19 @@ def predict_img(net,
|
|||
output_left = net(X_left)
|
||||
output_right = net(X_right)
|
||||
|
||||
left_probs = F.sigmoid(output_left)
|
||||
right_probs = F.sigmoid(output_right)
|
||||
left_probs = F.sigmoid(output_left).squeeze(0)
|
||||
right_probs = F.sigmoid(output_right).squeeze(0)
|
||||
|
||||
left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
||||
right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
||||
tf = transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize(img_height),
|
||||
transforms.ToTensor()
|
||||
]
|
||||
)
|
||||
|
||||
left_probs = tf(left_probs.cpu())
|
||||
right_probs = tf(right_probs.cpu())
|
||||
|
||||
left_mask_np = left_probs.squeeze().cpu().numpy()
|
||||
right_mask_np = right_probs.squeeze().cpu().numpy()
|
||||
|
@ -66,6 +76,7 @@ def get_args():
|
|||
" (default : 'MODEL.pth')")
|
||||
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
|
||||
help='filenames of input images', required=True)
|
||||
|
||||
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
||||
help='filenames of ouput images')
|
||||
parser.add_argument('--cpu', '-c', action='store_true',
|
||||
|
|
|
@ -2,4 +2,5 @@ matplotlib
|
|||
pydensecrf
|
||||
numpy
|
||||
Pillow
|
||||
torch
|
||||
torch
|
||||
torchvision
|
||||
|
|
Loading…
Reference in a new issue