2018-04-09 03:15:24 +00:00
|
|
|
import argparse
|
2019-10-24 19:37:21 +00:00
|
|
|
import logging
|
2018-04-09 03:15:24 +00:00
|
|
|
|
2022-06-29 14:12:00 +00:00
|
|
|
import albumentations as A
|
2022-07-05 10:06:12 +00:00
|
|
|
import cv2
|
2018-06-08 17:27:32 +00:00
|
|
|
import numpy as np
|
2017-08-21 16:00:07 +00:00
|
|
|
import torch
|
2022-06-29 14:12:00 +00:00
|
|
|
from albumentations.pytorch import ToTensorV2
|
2018-06-08 17:27:32 +00:00
|
|
|
from PIL import Image
|
2017-08-21 16:00:07 +00:00
|
|
|
|
2022-06-27 13:39:44 +00:00
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
def get_args(argv=None):
    """Parse command-line options for single-image mask prediction.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) keeps the usual CLI
            behaviour. Useful for tests and programmatic invocation.

    Returns:
        argparse.Namespace with ``model``, ``input`` and ``output``
        attributes. ``output`` is ``None`` when ``--output`` is omitted.

    Raises:
        SystemExit: raised by argparse when ``--input`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(
        description="Predict masks from input images",
    )
    parser.add_argument(
        "--model",
        "-m",
        # NOTE(review): the script loads this file with
        # cv2.dnn.readNetFromONNX, so a ".pth" default will normally need to
        # be overridden with an exported .onnx file — confirm intended default.
        default="model.pth",
        metavar="FILE",
        help="Specify the ONNX file in which the model is stored",
    )
    parser.add_argument(
        "--input",
        "-i",
        metavar="INPUT",
        help="Filename of the input image",
        required=True,
    )
    parser.add_argument(
        "--output",
        "-o",
        metavar="OUTPUT",
        help="Filename of the output mask image",
    )
    return parser.parse_args(argv)
|
2017-11-30 05:45:19 +00:00
|
|
|
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2022-07-05 10:06:12 +00:00
|
|
|
def sigmoid(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    The value is computed from ``z = exp(-|x|)``, whose exponent is never
    positive, so large-magnitude negative logits no longer overflow
    ``np.exp`` (the naive form emits "overflow encountered in exp"). For
    ``x >= 0`` the result is ``1 / (1 + z)``; for ``x < 0`` it is
    ``z / (1 + z)``, which is mathematically the same sigmoid value.

    Args:
        x: Scalar or array-like of logits.

    Returns:
        The sigmoid of ``x``. Array inputs keep their floating dtype
        (e.g. float32 stays float32); scalar inputs return a NumPy scalar.
    """
    arr = np.asarray(x)
    z = np.exp(-np.abs(arr))  # always in (0, 1], so no overflow
    out = np.where(arr >= 0, 1 / (1 + z), z / (1 + z))
    # np.where always returns an ndarray; unwrap 0-d results to a scalar.
    return out[()] if out.ndim == 0 else out
|
|
|
|
|
|
|
|
|
2022-06-27 13:39:44 +00:00
|
|
|
if __name__ == "__main__":
    # Local import: only used to derive a default output filename.
    from pathlib import Path

    args = get_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    net = cv2.dnn.readNetFromONNX(args.model)
    logging.info("onnx model loaded")

    logging.info(f"Loading image {args.input}")
    input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, so fail here with a clear message.
    if input_img is None:
        raise SystemExit(f"Could not read input image: {args.input}")
    input_img = input_img.astype(np.float32)
    # NOTE(review): an earlier revision resized the image to 512x512 here;
    # add a cv2.resize call back if the exported ONNX graph needs a fixed size.
    # NOTE(review): cv2.imread returns BGR channel order and blobFromImage is
    # called without swapRB — confirm this matches the model's training input.

    logging.info("converting to blob")
    input_blob = cv2.dnn.blobFromImage(
        image=input_img,
        scalefactor=1 / 255,
    )

    net.setInput(input_blob)
    logits = net.forward()
    mask = sigmoid(logits) > 0.5

    # Strip leading size-1 axes (e.g. batch/channel) so a single-channel
    # prediction becomes a 2-D (H, W) array PIL can save as grayscale.
    while mask.ndim > 2 and mask.shape[0] == 1:
        mask = mask[0]
    if mask.ndim != 2:
        raise SystemExit(
            f"Expected a single-channel mask, got model output of shape {logits.shape}"
        )
    # Encode as 0/255 uint8: passing float32 data to PIL with mode "L" makes
    # it reinterpret the raw float bytes as 8-bit pixels, producing garbage.
    mask = mask.astype(np.uint8) * 255

    output_path = args.output
    if output_path is None:
        # --output is optional; default to "<input stem>_OUT.png" beside the input.
        in_path = Path(args.input)
        output_path = str(in_path.with_name(f"{in_path.stem}_OUT.png"))

    logging.info(f"Saving prediction to {output_path}")
    Image.fromarray(mask).save(output_path)
|