REVA-QCAV/src/predict.py

144 lines
3.8 KiB
Python
Raw Normal View History

import argparse
import logging
import os
import numpy as np
2017-08-21 16:00:07 +00:00
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
2017-08-21 16:00:07 +00:00
from src.utils.dataset import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask
def get_args():
parser = argparse.ArgumentParser(
description="Predict masks from input images",
)
parser.add_argument(
"--model",
"-m",
default="MODEL.pth",
metavar="FILE",
help="Specify the file in which the model is stored",
)
parser.add_argument(
"--input",
"-i",
metavar="INPUT",
nargs="+",
help="Filenames of input images",
required=True,
)
parser.add_argument(
"--output",
"-o",
metavar="OUTPUT",
nargs="+",
help="Filenames of output images",
)
parser.add_argument(
"--viz",
"-v",
action="store_true",
help="Visualize the images as they are processed",
)
parser.add_argument(
"--no-save",
"-n",
action="store_true",
help="Do not save the output masks",
)
parser.add_argument(
"--mask-threshold",
"-t",
type=float,
default=0.5,
help="Minimum probability value to consider a mask pixel white",
)
parser.add_argument(
"--scale",
"-s",
type=float,
default=0.5,
help="Scale factor for the input images",
)
return parser.parse_args()
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
with torch.inference_mode():
output = net(img)
probs = torch.sigmoid(output)[0]
tf = transforms.Compose(
[transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
)
full_mask = tf(probs.cpu()).squeeze()
if net.n_classes == 1:
return (full_mask > out_threshold).numpy()
else:
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
def get_output_filenames(args):
def _generate_name(fn):
split = os.path.splitext(fn)
return f"{split[0]}_OUT{split[1]}"
return args.output or list(map(_generate_name, args.input))
def mask_to_image(mask: np.ndarray):
if mask.ndim == 2:
return Image.fromarray((mask * 255).astype(np.uint8))
elif mask.ndim == 3:
return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
if __name__ == "__main__":
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
net = UNet(n_channels=3, n_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Loading model {args.model}")
logging.info(f"Using device {device}")
net.to(device=device)
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info("Model loaded!")
for i, filename in enumerate(in_files):
logging.info(f"\nPredicting image {filename} ...")
img = Image.open(filename)
mask = predict_img(
net=net, full_img=img, scale_factor=args.scale, out_threshold=args.mask_threshold, device=device
)
if not args.no_save:
out_filename = out_files[i]
result = mask_to_image(mask)
result.save(out_filename)
logging.info(f"Mask saved to {out_filename}")
if args.viz:
logging.info(f"Visualizing results for image {filename}, close to continue...")
plot_img_and_mask(img, mask)