Merge pull request #13 from ZiyuanTonyZhang/master

fix scaling issue

Former-commit-id: b282e761f04904697418c9e759d761004a045428
This commit is contained in:
milesial 2018-06-23 19:59:38 +02:00 committed by GitHub
commit 14e0ff1c62
2 changed files with 18 additions and 6 deletions

21
predict.py Normal file → Executable file
View file

@ -11,6 +11,8 @@ from unet import UNet
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
from utils import plot_img_and_mask from utils import plot_img_and_mask
from torchvision import transforms
def predict_img(net, def predict_img(net,
full_img, full_img,
scale_factor=0.5, scale_factor=0.5,
@ -31,7 +33,7 @@ def predict_img(net,
X_left = torch.from_numpy(left_square).unsqueeze(0) X_left = torch.from_numpy(left_square).unsqueeze(0)
X_right = torch.from_numpy(right_square).unsqueeze(0) X_right = torch.from_numpy(right_square).unsqueeze(0)
if use_gpu: if use_gpu:
X_left = X_left.cuda() X_left = X_left.cuda()
X_right = X_right.cuda() X_right = X_right.cuda()
@ -40,11 +42,19 @@ def predict_img(net,
output_left = net(X_left) output_left = net(X_left)
output_right = net(X_right) output_right = net(X_right)
left_probs = F.sigmoid(output_left) left_probs = F.sigmoid(output_left).squeeze(0)
right_probs = F.sigmoid(output_right) right_probs = F.sigmoid(output_right).squeeze(0)
left_probs = F.upsample(left_probs, size=(img_height, img_height)) tf = transforms.Compose(
right_probs = F.upsample(right_probs, size=(img_height, img_height)) [
transforms.ToPILImage(),
transforms.Resize(img_height),
transforms.ToTensor()
]
)
left_probs = tf(left_probs.cpu())
right_probs = tf(right_probs.cpu())
left_mask_np = left_probs.squeeze().cpu().numpy() left_mask_np = left_probs.squeeze().cpu().numpy()
right_mask_np = right_probs.squeeze().cpu().numpy() right_mask_np = right_probs.squeeze().cpu().numpy()
@ -66,6 +76,7 @@ def get_args():
" (default : 'MODEL.pth')") " (default : 'MODEL.pth')")
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
help='filenames of input images', required=True) help='filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
help='filenames of ouput images') help='filenames of ouput images')
parser.add_argument('--cpu', '-c', action='store_true', parser.add_argument('--cpu', '-c', action='store_true',

View file

@ -2,4 +2,5 @@ matplotlib
pydensecrf pydensecrf
numpy numpy
Pillow Pillow
torch torch
torchvision