REVA-QCAV/src/predict.py

82 lines
1.9 KiB
Python
Raw Normal View History

import argparse
import logging
import albumentations as A
import numpy as np
import onnx
import onnxruntime
2017-08-21 16:00:07 +00:00
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
2017-08-21 16:00:07 +00:00
def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what a plain ``get_args()`` call does.

    Returns:
        argparse.Namespace: with the attributes ``model`` (defaults to
        ``"model.pth"``), ``input`` (mandatory) and ``output`` (``None``
        when the option is not given).

    Raises:
        SystemExit: raised by argparse when ``--input`` is missing or an
            unknown option is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Predict masks from input images",
    )

    # NOTE(review): the default points at "model.pth" while the script entry
    # point in this file loads the model with onnx.load -- confirm the
    # default is still appropriate.
    parser.add_argument(
        "--model",
        "-m",
        default="model.pth",
        metavar="FILE",
        help="Specify the file in which the model is stored",
    )
    parser.add_argument(
        "--input",
        "-i",
        metavar="INPUT",
        help="Filenames of input images",
        required=True,
    )
    # Not required and no default: callers must cope with ``output is None``.
    parser.add_argument(
        "--output",
        "-o",
        metavar="OUTPUT",
        help="Filenames of output images",
    )

    return parser.parse_args(argv)
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*, element-wise.

    The naive formula overflows in ``exp(-x)`` for large negative inputs
    (emitting NumPy overflow warnings, or raising under
    ``np.seterr(over="raise")``). This version only ever exponentiates
    non-positive values, so it cannot overflow.

    Args:
        x: A real scalar or array-like of real numbers.

    Returns:
        A NumPy scalar for scalar input, otherwise an array of the same
        shape as *x*. Floating-point dtypes are preserved; integer and
        boolean inputs are computed in float64.
    """
    values = np.asarray(x)

    # Anything that is not a plain real number (complex, object, ...) keeps
    # the straightforward formula so its behaviour is left untouched.
    if values.dtype.kind not in "biuf":
        return 1 / (1 + np.exp(-x))

    if values.dtype.kind != "f":
        values = values.astype(np.float64)

    # exp(-|x|) lies in (0, 1], so neither branch below can overflow.
    decay = np.exp(-np.abs(values))
    result = np.where(values >= 0, 1 / (1 + decay), decay / (1 + decay))

    # np.where yields a 0-d array for scalar input; unwrap it to a scalar.
    return result[()] if result.ndim == 0 else result
if __name__ == "__main__":
    # Script entry point: run one image through an ONNX model and save the
    # resulting mask as an 8-bit grayscale image.
    args = get_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # --output is optional in the parser, but the result is always written to
    # it; fail before loading the model rather than after inference.
    if args.output is None:
        raise SystemExit("error: no output file given, use --output/-o")

    # Validate the serialized graph before handing the file to the runtime.
    onnx_model = onnx.load(args.model)
    onnx.checker.check_model(onnx_model)

    ort_session = onnxruntime.InferenceSession(args.model)

    def to_numpy(tensor):
        """Return *tensor* as a CPU NumPy array, detaching it first if it tracks gradients."""
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    # Open inside a context manager so the underlying file handle is closed;
    # convert() loads the pixel data and returns an independent image.
    with Image.open(args.input) as source:
        img = source.convert("RGB")

    logging.info(f"Preprocessing image {args.input}")
    # Scale 8-bit pixel values by 1/255 and convert the array to a tensor.
    tf = A.Compose(
        [
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )
    aug = tf(image=np.asarray(img))
    img = aug["image"]

    logging.info(f"Predicting image {args.input}")
    # Add a leading axis -- presumably the batch dimension the model expects.
    img = img.unsqueeze(0)

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
    ort_outs = ort_session.run(None, ort_inputs)

    img_out_y = ort_outs[0]

    # Take index 0 along the first two axes of the first output (presumably
    # batch and channel -- TODO confirm), rescale to 0..255 and clip before
    # the uint8 cast so out-of-range values do not wrap around.
    img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
    img_out_y.save(args.output)