mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
179b08e6a3
Former-commit-id: 7f855d53668a576236b16670fd052d8caf816d80
166 lines
5.3 KiB
Python
Executable file
166 lines
5.3 KiB
Python
Executable file
import argparse
|
|
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from PIL import Image
|
|
|
|
from unet import UNet
|
|
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
|
from utils import plot_img_and_mask
|
|
|
|
from torchvision import transforms
|
|
|
|
def predict_img(net,
                full_img,
                scale_factor=0.5,
                out_threshold=0.5,
                use_dense_crf=True,
                use_gpu=False):
    """Predict a boolean segmentation mask for ``full_img`` using ``net``.

    The image is rescaled and normalized, split into a left and a right
    square, and each square is run through the network separately. The two
    probability maps are resized back to the original image height, merged
    into one full-width map, optionally refined with a dense CRF, and
    finally thresholded.

    Args:
        net: Callable model; it is invoked once per square on a tensor with
            a leading batch dimension of 1.
        full_img: Input image exposing ``size`` as ``(width, height)``
            (the script passes a PIL image).
        scale_factor: Scale forwarded to ``resize_and_crop``.
        out_threshold: Values strictly above this threshold are ``True`` in
            the returned mask.
        use_dense_crf: When true, post-process the merged map with
            ``dense_crf`` before thresholding.
        use_gpu: When true, move the input tensors to CUDA before the
            forward pass (the network is expected to already live there).

    Returns:
        The result of ``full_mask > out_threshold``.
    """
    img_width, img_height = full_img.size[0], full_img.size[1]

    img = resize_and_crop(full_img, scale=scale_factor)
    img = normalize(img)

    left_square, right_square = split_img_into_squares(img)

    # Brings each probability map back to the height of the original image.
    tf = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(img_height),
            transforms.ToTensor()
        ]
    )

    half_masks = []
    for square in (left_square, right_square):
        # NOTE(review): hwc_to_chw presumably reorders HxWxC to CxHxW, as
        # its name suggests - confirm against utils.
        batch = torch.from_numpy(hwc_to_chw(square)).unsqueeze(0)
        if use_gpu:
            batch = batch.cuda()

        with torch.no_grad():
            # torch.sigmoid replaces the deprecated F.sigmoid.
            probs = torch.sigmoid(net(batch)).squeeze(0)

        # ToPILImage needs a CPU tensor; ToTensor hands back a CPU tensor.
        probs = tf(probs.cpu())
        half_masks.append(probs.squeeze().numpy())

    left_mask_np, right_mask_np = half_masks
    full_mask = merge_masks(left_mask_np, right_mask_np, img_width)

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
|
|
|
|
|
|
|
|
def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) ``argparse`` reads the arguments from ``sys.argv``,
            which is what the script entry point relies on.

    Returns:
        The ``argparse.Namespace`` with attributes ``model``, ``input``,
        ``output``, ``cpu``, ``viz``, ``no_save``, ``no_crf``,
        ``mask_threshold`` and ``scale``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which is stored the model"
                             " (default : 'MODEL.pth')")
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='filenames of input images', required=True)

    # The metavar used to be a copy-pasted 'INPUT', which made the usage
    # text for --output misleading.
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='filenames of output images')
    # store_true flags already default to False.
    parser.add_argument('--cpu', '-c', action='store_true',
                        help="Do not use the cuda version of the net")
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed")
    parser.add_argument('--no-save', '-n', action='store_true',
                        help="Do not save the output masks")
    parser.add_argument('--no-crf', '-r', action='store_true',
                        help="Do not use dense CRF postprocessing")
    parser.add_argument('--mask-threshold', '-t', type=float,
                        help="Minimum probability value to consider a mask pixel white",
                        default=0.5)
    parser.add_argument('--scale', '-s', type=float,
                        help="Scale factor for the input images",
                        default=0.5)

    return parser.parse_args(argv)
|
|
|
|
def get_output_filenames(args):
    """Return the output filename to use for every input file.

    Args:
        args: Parsed arguments exposing ``input`` (list of input filenames)
            and ``output`` (list of output filenames, or a falsy value when
            none were given).

    Returns:
        ``args.output`` when it was provided with one entry per input file;
        otherwise a new list in which each input name has ``_OUT`` inserted
        before its extension (``a.jpg`` becomes ``a_OUT.jpg``).

    Raises:
        SystemExit: If output names were given but their count differs from
            the number of input files. The message is carried by the
            exception so the interpreter reports it on stderr and exits
            with a non-zero status instead of 0.
    """
    in_files = args.input

    if not args.output:
        return [
            "{}_OUT{}".format(*os.path.splitext(f))
            for f in in_files
        ]

    if len(in_files) != len(args.output):
        raise SystemExit(
            "Error : Input files and output files are not of the same length"
        )

    return args.output
|
|
|
|
def mask_to_image(mask):
    """Convert a mask array into a PIL image.

    The mask is multiplied by 255 and cast to ``uint8`` before being handed
    to ``Image.fromarray``, so a boolean mask maps to pixel values 0 and 255.
    """
    pixels = (mask * 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the model once, then predict (and optionally
    # visualize / save) a mask for every input image.
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=1)

    print("Loading model {}".format(args.model))

    # NOTE: torch.load unpickles the file, so only load model files that
    # come from a trusted source.
    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine.
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    # Inference only: switch layers such as dropout / batch-norm (if the
    # network has any) to evaluation behaviour.
    net.eval()

    print("Model loaded !")

    # get_output_filenames guarantees one output name per input file.
    for fn, out_fn in zip(in_files, out_files):
        print("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        if img.size[0] < img.size[1]:
            # Only reported; the prediction is still attempted afterwards.
            print("Error: image height larger than the width")

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           use_dense_crf=not args.no_crf,
                           use_gpu=not args.cpu)

        if args.viz:
            print("Visualizing results for image {}, close to continue ...".format(fn))
            plot_img_and_mask(img, mask)

        if not args.no_save:
            result = mask_to_image(mask)
            result.save(out_fn)

            print("Mask saved to {}".format(out_fn))
|