mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-14 00:58:15 +00:00
9d7be6e234
Former-commit-id: 79928c84cdf990ef6fe1043a3e4f74b9cc252642
144 lines
4.4 KiB
Python
Executable file
144 lines
4.4 KiB
Python
Executable file
import argparse
|
|
import logging
|
|
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
import torch.nn.functional as F
|
|
|
|
from unet import UNet
|
|
from utils.data_vis import plot_img_and_mask
|
|
from utils.dataset import BasicDataset
|
|
from utils.crf import dense_crf
|
|
|
|
|
|
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    """Predict a thresholded segmentation mask for a single image.

    Args:
        net: segmentation network exposing an ``n_classes`` attribute. Its
            output is assumed to be logits shaped (1, n_classes, h, w) for a
            batched input (TODO confirm against ``unet.UNet``).
        full_img: the input image at full resolution, either a PIL image (as
            passed by this script's ``__main__`` block) or an array whose
            first two dimensions are (height, width).
        device: torch device the network lives on; the input is moved there.
        scale_factor: scale handed to ``BasicDataset`` for preprocessing.
        out_threshold: probability above which a pixel counts as foreground.
        use_dense_crf: if True, refine the probability map with ``dense_crf``
            before thresholding.

    Returns:
        The result of ``full_mask > out_threshold``, where ``full_mask`` is the
        probability map resized back to the size of ``full_img``.
    """
    net.eval()

    ds = BasicDataset('', '', scale=scale_factor)
    img = ds.preprocess(full_img)
    img = img.unsqueeze(0)  # add the batch dimension
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        # PIL images have no ``.shape``; they expose (width, height) via
        # ``.size``. Arrays expose (height, width, ...) via ``.shape``.
        if hasattr(full_img, 'shape'):
            full_h, full_w = full_img.shape[:2]
        else:
            full_w, full_h = full_img.size

        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                # An explicit (h, w) pair restores the exact input size; a
                # single int would only match the shorter edge.
                transforms.Resize((full_h, full_w)),
                transforms.ToTensor()
            ]
        )

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of the prediction script.

    Returns:
        argparse.Namespace with attributes ``model``, ``input`` (list of
        paths), ``output`` (list of paths or None), ``viz``, ``no_save``,
        ``mask_threshold`` and ``scale``.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which the model is stored")
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='filenames of input images', required=True)

    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='Filenames of output images')
    # store_true flags already default to False.
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed")
    parser.add_argument('--no-save', '-n', action='store_true',
                        help="Do not save the output masks")
    parser.add_argument('--mask-threshold', '-t', type=float,
                        help="Minimum probability value to consider a mask pixel white",
                        default=0.5)
    parser.add_argument('--scale', '-s', type=float,
                        help="Scale factor for the input images",
                        default=0.5)

    return parser.parse_args()
|
|
|
|
|
|
def get_output_filenames(args):
    """Return one output filename per input image.

    If ``args.output`` is empty or None, each output name is derived from the
    matching input by inserting ``_OUT`` before the extension
    (``dir/img.png`` -> ``dir/img_OUT.png``). Otherwise ``args.output`` is
    returned as given and must be as long as ``args.input``.

    Raises:
        SystemExit: with exit status 1 when the input and output lists have
            different lengths.
    """
    in_files = args.input

    if not args.output:
        return [
            "{}_OUT{}".format(root, ext)
            for root, ext in map(os.path.splitext, in_files)
        ]

    if len(in_files) != len(args.output):
        logging.error("Input files and output files are not of the same length")
        # A bare SystemExit() exits with status 0, which would report success
        # to the calling shell; signal failure instead.
        raise SystemExit(1)

    return args.output
|
|
|
|
|
|
def mask_to_image(mask):
    """Turn a mask array into an 8-bit PIL image.

    Every mask value is multiplied by 255 and cast to ``uint8``, so a boolean
    mask maps to black (0) and white (255) pixels.
    """
    pixel_values = mask * 255
    return Image.fromarray(pixel_values.astype(np.uint8))
|
|
|
|
|
|
if __name__ == "__main__":
    # Without a handler configured, the root logger's default WARNING level
    # silently discards every logging.info call below.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=1)

    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    # get_output_filenames guarantees one output name per input file.
    for fn, out_fn in zip(in_files, out_files):
        logging.info("\nPredicting image {} ...".format(fn))

        # The context manager closes the image's file handle once done.
        with Image.open(fn) as img:
            mask = predict_img(net=net,
                               full_img=img,
                               scale_factor=args.scale,
                               out_threshold=args.mask_threshold,
                               use_dense_crf=False,
                               device=device)

            if not args.no_save:
                result = mask_to_image(mask)
                result.save(out_fn)

                logging.info("Mask saved to {}".format(out_fn))

            if args.viz:
                logging.info("Visualizing results for image {}, close to continue ...".format(fn))
                plot_img_and_mask(img, mask)
|