2018-04-09 03:15:24 +00:00
|
|
|
import argparse
|
2019-10-24 19:37:21 +00:00
|
|
|
import logging
|
2018-06-08 17:27:32 +00:00
|
|
|
import os
|
2018-04-09 03:15:24 +00:00
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
import numpy as np
|
2017-08-21 16:00:07 +00:00
|
|
|
import torch
|
2019-12-21 21:04:23 +00:00
|
|
|
import torch.nn.functional as F
|
2018-06-08 17:27:32 +00:00
|
|
|
from PIL import Image
|
2019-10-24 19:37:21 +00:00
|
|
|
from torchvision import transforms
|
2017-08-21 16:00:07 +00:00
|
|
|
|
2021-08-16 00:53:00 +00:00
|
|
|
from data_loading import BasicDataset
|
2017-11-30 05:45:19 +00:00
|
|
|
from unet import UNet
|
2021-08-16 00:53:00 +00:00
|
|
|
from utils import plot_img_and_mask
|
2019-11-23 16:56:14 +00:00
|
|
|
|
2018-06-17 06:31:42 +00:00
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Run ``net`` on one PIL image and return a mask at the image's full resolution.

    Args:
        net: segmentation model exposing an ``n_classes`` attribute.
        full_img: input PIL image; its ``.size`` is read as (width, height).
        device: torch device the input tensor is moved to (should match ``net``).
        scale_factor: passed to ``BasicDataset.preprocess`` to downscale the
            input before inference.
        out_threshold: probability cut-off, used only when ``net.n_classes == 1``.

    Returns:
        numpy array: a boolean (H, W) mask when ``net.n_classes == 1``,
        otherwise a one-hot (n_classes, H, W) integer array. H and W are the
        original image height and width.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)  # add batch dimension
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        # assumes net returns (1, n_classes, h, w) logits — TODO confirm against UNet
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # Resize the probabilities back to the original resolution directly as a
        # tensor. Going through a PIL image would cap the channel count at 4 and
        # quantize probabilities to 8 bits. PIL size is (W, H); interpolate wants (H, W).
        full_size = (full_img.size[1], full_img.size[0])
        probs = F.interpolate(probs, size=full_size, mode='bilinear', align_corners=False)
        full_mask = probs[0].cpu()  # (C, H, W)

    if net.n_classes == 1:
        # Index the single channel explicitly instead of squeeze(), which would
        # also drop a spatial dimension of size 1.
        return (full_mask[0] > out_threshold).numpy()
    return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
|
2018-06-08 17:27:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_args():
    """Parse the command-line options for mask prediction.

    Returns:
        argparse.Namespace with ``model``, ``input`` (list), ``output`` (list
        or None), ``viz``, ``no_save``, ``mask_threshold`` and ``scale``.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
    # metavar was 'INPUT' here, which mislabelled this option in --help/usage.
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')

    return parser.parse_args()
|
2017-11-30 05:45:19 +00:00
|
|
|
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
def get_output_filenames(args):
    """Return one output path per input image.

    If ``args.output`` is given, those names are used. Otherwise each input
    path gets ``_OUT`` inserted before its extension
    (``dir/img.png`` -> ``dir/img_OUT.png``).

    Raises:
        ValueError: if explicit output names were given but their count does
            not match the number of input files, so not every input would
            have a destination.
    """
    def _generate_name(fn):
        stem, ext = os.path.splitext(fn)
        return f'{stem}_OUT{ext}'

    if not args.output:
        return [_generate_name(fn) for fn in args.input]

    # Fail up front rather than with an IndexError halfway through a batch.
    if len(args.output) != len(args.input):
        raise ValueError(
            f'Got {len(args.input)} input file(s) but {len(args.output)} output file(s); '
            'the counts must match')
    return list(args.output)
|
2018-06-08 17:27:32 +00:00
|
|
|
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2021-08-16 00:53:00 +00:00
|
|
|
def mask_to_image(mask: np.ndarray):
    """Convert a predicted mask into an 8-bit grayscale PIL image.

    Args:
        mask: either a 2-D (H, W) mask (boolean or values in [0, 1]), or a
            3-D (C, H, W) per-class mask such as a one-hot encoding.

    Returns:
        PIL.Image. A 2-D mask maps 1/True to 255. A 3-D mask is reduced to its
        per-pixel argmax class, with classes spread evenly over 0..255 so the
        last class is white.

    Raises:
        ValueError: if ``mask`` is neither 2-D nor 3-D.
    """
    if mask.ndim == 2:
        return Image.fromarray((mask * 255).astype(np.uint8))
    if mask.ndim == 3:
        # NumPy's argmax takes ``axis``; ``dim`` (the torch spelling) raises TypeError.
        class_idx = np.argmax(mask, axis=0)
        # Divide by (C - 1) so the highest class maps to 255, matching the 2-D
        # branch, where foreground is white. Guard against a single-channel mask.
        step = 255 / max(mask.shape[0] - 1, 1)
        return Image.fromarray((class_idx * step).astype(np.uint8))
    raise ValueError(f'Expected a 2-D or 3-D mask, got shape {mask.shape}')
|
2018-06-08 17:27:32 +00:00
|
|
|
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2021-08-16 00:53:00 +00:00
|
|
|
if __name__ == '__main__':
    # Without a handler configured, the root logger stays at WARNING and every
    # logging.info call below would be silently dropped.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    # NOTE(review): architecture is hard-coded; n_classes must match the checkpoint.
    net = UNet(n_channels=3, n_classes=2)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info('Model loaded!')

    for i, filename in enumerate(in_files):
        logging.info(f'\nPredicting image {filename} ...')
        # Close each image's file handle once it has been processed, so large
        # batches do not accumulate open files.
        with Image.open(filename) as img:
            mask = predict_img(net=net,
                               full_img=img,
                               scale_factor=args.scale,
                               out_threshold=args.mask_threshold,
                               device=device)

            if not args.no_save:
                out_filename = out_files[i]
                result = mask_to_image(mask)
                result.save(out_filename)
                logging.info(f'Mask saved to {out_filename}')

            if args.viz:
                logging.info(f'Visualizing results for image {filename}, close to continue...')
                plot_img_and_mask(img, mask)
|