REVA-QCAV/predict.py
milesial 7ea54febec Added CLI for predict, cleaned up code, updated README
Former-commit-id: 77555ccc0925a8fba796ce7e42843d95b6e9dce0
2017-11-30 06:45:19 +01:00

119 lines
3.8 KiB
Python

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy
from PIL import Image
import argparse
import os
from utils import *
from crf import dense_crf
from unet import UNet
def _predict_square(net, square, gpu):
    """Run *net* on one square crop and return its 2-D probability map.

    The crop is normalized with ``normalize``, moved from channels-last to
    channels-first layout (assumes ``square`` is (H, W, C) — the UNet below
    is built with 3 input channels), given a batch dimension, passed through
    *net* and a sigmoid, then bilinearly upsampled by a factor of 2.
    Returns the first channel of the first batch item as a numpy array.
    """
    chw = numpy.transpose(normalize(square), axes=[2, 0, 1])
    X = Variable(torch.FloatTensor(chw).unsqueeze(0), volatile=True)
    if gpu:
        X = X.cuda()
    y = F.sigmoid(net(X))
    return F.upsample_bilinear(y, scale_factor=2).data[0][0].cpu().numpy()


def predict_img(net, full_img, gpu=False):
    """Predict a binary segmentation mask for a full-resolution image.

    The image is resized/cropped with ``resize_and_crop`` and split into two
    squares (``get_square(img, 0)`` and ``get_square(img, 1)``). Each square
    is run through *net* separately. The two upsampled probability maps are
    joined with ``merge_masks`` and refined with a dense CRF against the
    original image.

    Args:
        net: segmentation network, called on a 1 x C x H x W tensor.
        full_img: input image as opened with ``PIL.Image.open``.
        gpu: when True, move the input tensors to CUDA before the forward
            passes.

    Returns:
        A boolean numpy array: the CRF output thresholded at 0.5.
    """
    img = resize_and_crop(full_img)

    y_l = _predict_square(net, get_square(img, 0), gpu)
    y_r = _predict_square(net, get_square(img, 1), gpu)

    # dense_crf pairs the mask pixel-for-pixel with full_img, so the merged
    # mask must be exactly as wide as the original image (was hard-coded to
    # 1918, the width of the images this script was first written for).
    full_w = full_img.size[0]
    y = merge_masks(y_l, y_r, full_w)
    yy = dense_crf(numpy.array(full_img).astype(numpy.uint8), y)
    return yy > 0.5
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', default='MODEL.pth',
metavar='FILE',
help="Specify the file in which is stored the model"
" (default : 'MODEL.pth')")
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
help='filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
help='filenames of ouput images')
parser.add_argument('--cpu', '-c', action='store_true',
help="Do not use the cuda version of the net",
default=False)
parser.add_argument('--viz', '-v', action='store_true',
help="Visualize the images as they are processed",
default=False)
parser.add_argument('--no-save', '-n', action='store_false',
help="Do not save the output masks",
default=False)
args = parser.parse_args()
print("Using model file : {}".format(args.model))
net = UNet(3, 1)
if not args.cpu:
print("Using CUDA version of the net, prepare your GPU !")
net.cuda()
else:
net.cpu()
print("Using CPU version of the net, this may be very slow")
in_files = args.input
out_files = []
if not args.output:
for f in in_files:
pathsplit = os.path.splitext(f)
out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
elif len(in_files) != len(args.output):
print("Error : Input files and output files are not of the same length")
raise SystemExit()
else:
out_files = args.output
print("Loading model ...")
net.load_state_dict(torch.load(args.model))
print("Model loaded !")
for i, fn in enumerate(in_files):
print("\nPredicting image {} ...".format(fn))
img = Image.open(fn)
out = predict_img(net, img, not args.cpu)
if args.viz:
print("Vizualising results for image {}, close to continue ..."
.format(fn))
fig = plt.figure()
a = fig.add_subplot(1, 2, 1)
a.set_title('Input image')
plt.imshow(img)
b = fig.add_subplot(1, 2, 2)
b.set_title('Output mask')
plt.imshow(out)
plt.show()
if not args.no_save:
out_fn = out_files[i]
result = Image.fromarray((out * 255).astype(numpy.uint8))
result.save(out_files[i])
print("Mask saved to {}".format(out_files[i]))