feat: working prediction but only for 512x512

Former-commit-id: c9d88ad18de91409fc1be1f1abe59d6e75ff2235 [formerly 8bf12bad1c3e8424aa26c7bd9a441facc670b059]
Former-commit-id: 820e327f8a79bad36d1f15944012e77ba1ecd560
This commit is contained in:
Laurent Fainsin 2022-07-05 12:17:32 +02:00
parent 0fb1d4fb7a
commit 36b044c719
4 changed files with 114 additions and 21 deletions

2
.vscode/launch.json vendored
View file

@ -12,7 +12,7 @@
"console": "integratedTerminal",
"args": [
"--input",
"images/SM.png",
"images/test.png",
"--output",
"output_onnx.png",
"--model",

83
poetry.lock generated
View file

@ -240,6 +240,14 @@ category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "flatbuffers"
version = "2.0"
description = "The FlatBuffers serialization format for Python"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "fonttools"
version = "4.33.3"
@ -673,6 +681,35 @@ rsa = ["cryptography (>=3.0.0)"]
signals = ["blinker (>=1.4.0)"]
signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "onnx"
version = "1.12.0"
description = "Open Neural Network Exchange"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
numpy = ">=1.16.6"
protobuf = ">=3.12.2,<=3.20.1"
typing-extensions = ">=3.6.2.1"
[package.extras]
lint = ["clang-format (==13.0.0)", "flake8", "mypy (==0.782)", "types-protobuf (==3.18.4)"]
[[package]]
name = "onnxruntime"
version = "1.11.1"
description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
flatbuffers = "*"
numpy = ">=1.21.6"
protobuf = "*"
[[package]]
name = "opencv-python-headless"
version = "4.6.0.66"
@ -1446,7 +1483,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.11"
content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5"
content-hash = "c1855a97cbe537d31526a76a49fde822e98acc89713f5f902639327b688c079a"
[metadata.files]
absl-py = [
@ -1692,6 +1729,10 @@ executing = [
{file = "executing-0.8.3-py2.py3-none-any.whl", hash = "sha256:d1eef132db1b83649a3905ca6dd8897f71ac6f8cac79a7e58a1a09cf137546c9"},
{file = "executing-0.8.3.tar.gz", hash = "sha256:c6554e21c6b060590a6d3be4b82fb78f8f0194d809de5ea7df1c093763311501"},
]
flatbuffers = [
{file = "flatbuffers-2.0-py2.py3-none-any.whl", hash = "sha256:3751954f0604580d3219ae49a85fafec9d85eec599c0b96226e1bc0b48e57474"},
{file = "flatbuffers-2.0.tar.gz", hash = "sha256:12158ab0272375eab8db2d663ae97370c33f152b27801fa6024e1d6105fd4dd2"},
]
fonttools = [
{file = "fonttools-4.33.3-py3-none-any.whl", hash = "sha256:f829c579a8678fa939a1d9e9894d01941db869de44390adb49ce67055a06cc2a"},
{file = "fonttools-4.33.3.zip", hash = "sha256:c0fdcfa8ceebd7c1b2021240bd46ef77aa8e7408cf10434be55df52384865f8e"},
@ -2056,6 +2097,46 @@ oauthlib = [
{file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"},
{file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"},
]
onnx = [
{file = "onnx-1.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bdbd2578424c70836f4d0f9dda16c21868ddb07cc8192f9e8a176908b43d694b"},
{file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213e73610173f6b2e99f99a4b0636f80b379c417312079d603806e48ada4ca8b"},
{file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fd2f4e23078df197bb76a59b9cd8f5a43a6ad2edc035edb3ecfb9042093e05a"},
{file = "onnx-1.12.0-cp310-cp310-win32.whl", hash = "sha256:23781594bb8b7ee985de1005b3c601648d5b0568a81e01365c48f91d1f5648e4"},
{file = "onnx-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:81a3555fd67be2518bf86096299b48fb9154652596219890abfe90bd43a9ec13"},
{file = "onnx-1.12.0-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:5578b93dc6c918cec4dee7fb7d9dd3b09d338301ee64ca8b4f28bc217ed42dca"},
{file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c11162ffc487167da140f1112f49c4f82d815824f06e58bc3095407699f05863"},
{file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341c7016e23273e9ffa9b6e301eee95b8c37d0f04df7cedbdb169d2c39524c96"},
{file = "onnx-1.12.0-cp37-cp37m-win32.whl", hash = "sha256:3c6e6bcffc3f5c1e148df3837dc667fa4c51999788c1b76b0b8fbba607e02da8"},
{file = "onnx-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8a7aa61aea339bd28f310f4af4f52ce6c4b876386228760b16308efd58f95059"},
{file = "onnx-1.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:56ceb7e094c43882b723cfaa107d85ad673cfdf91faeb28d7dcadacca4f43a07"},
{file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3629e8258db15d4e2c9b7f1be91a3186719dd94661c218c6f5fde3cc7de3d4d"},
{file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d9a7db54e75529160337232282a4816cc50667dc7dc34be178fd6f6b79d4705"},
{file = "onnx-1.12.0-cp38-cp38-win32.whl", hash = "sha256:fea5156a03398fe0e23248042d8651c1eaac5f6637d4dd683b4c1f1320b9f7b4"},
{file = "onnx-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66d2996e65f490a57b3ae952e4e9189b53cc9fe3f75e601d50d4db2dc1b1cd9"},
{file = "onnx-1.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c39a7a0352c856f1df30dccf527eb6cb4909052e5eaf6fa2772a637324c526aa"},
{file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab13feb4d94342aae6d357d480f2e47d41b9f4e584367542b21ca6defda9e0a"},
{file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9b3ea02c30efc1d2662337e280266aca491a8e86be0d8a657f874b7cccd1e"},
{file = "onnx-1.12.0-cp39-cp39-win32.whl", hash = "sha256:f8800f28c746ab06e51ef8449fd1215621f4ddba91be3ffc264658937d38a2af"},
{file = "onnx-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:af90427ca04c6b7b8107c2021e1273227a3ef1a7a01f3073039cae7855a59833"},
{file = "onnx-1.12.0.tar.gz", hash = "sha256:13b3e77d27523b9dbf4f30dfc9c959455859d5e34e921c44f712d69b8369eff9"},
]
onnxruntime = [
{file = "onnxruntime-1.11.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:88b94a900754ef189c2b06f2046f2de8008753e0e8a3e562b2fb03298026b4b4"},
{file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:958974be7808b46815533c74e8849a2d73e73d656df8369a114ce3359f77760b"},
{file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f83e7d52932b68b08cfba4920816efc7c3177036a90116137b11888e1f2490"},
{file = "onnxruntime-1.11.1-cp37-cp37m-win32.whl", hash = "sha256:3106bfcd1532afcaa26fc47931f7a8770dc710263647e8fbb5f75fa5a8fc70f9"},
{file = "onnxruntime-1.11.1-cp37-cp37m-win_amd64.whl", hash = "sha256:80775f4f64850b6774dbaa955888a89dc719cf654f1995ed5418e78c0139b5f4"},
{file = "onnxruntime-1.11.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:b24a3cd1e6d7fe7c4c5be2996ba02ebf8beed6347b2fd3ac869d1c685a2e0264"},
{file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73c4df0a446fe49d59629746d2163fa39b732b6afb3b5d00f8c9ec91a040e5c4"},
{file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149bf850c4e320e33894cab3e350a945ab17690cf54ffa00ef965273112ef614"},
{file = "onnxruntime-1.11.1-cp38-cp38-win32.whl", hash = "sha256:9e202d7323b5728cdc3c0ee3bbc35f10cd56c7120c9626887c4ebe5d8503b488"},
{file = "onnxruntime-1.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:00632fc2ee3cf86349f5b00f5385a62fe5720ef14b471c919cc2c94faeb446d0"},
{file = "onnxruntime-1.11.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:792985ddf3d3c46efa24bcfef970e7ccd4421d46173a96ca3974dab709598591"},
{file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ce95870b0bc7cbef5383b3c3062c6d9784af71f266192bc887928d7b927a46"},
{file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53121b024f68d6b16bb93bc3fb73ba05b6f55647d12054a8efae7f48ed761add"},
{file = "onnxruntime-1.11.1-cp39-cp39-win32.whl", hash = "sha256:b90124277454c50c5b2073bb9e1368b2a5672f30c2c8f3fe01393967dcd6dce2"},
{file = "onnxruntime-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:b1ffefc961fc607e5929fef92f3bc8bc48bd3a074b2a6448887be23eb313f75a"},
]
opencv-python-headless = [
{file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"},
{file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"},

View file

@ -15,6 +15,8 @@ torch = "^1.11.0"
torchvision = "^0.12.0"
tqdm = "^4.64.0"
wandb = "^0.12.19"
onnx = "^1.12.0"
onnxruntime = "^1.11.1"
[tool.poetry.dev-dependencies]
black = "^22.3.0"

View file

@ -2,8 +2,9 @@ import argparse
import logging
import albumentations as A
import cv2
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
@ -46,26 +47,35 @@ if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
net = cv2.dnn.readNetFromONNX(args.model)
logging.info("onnx model loaded")
onnx_model = onnx.load(args.model)
onnx.checker.check_model(onnx_model)
logging.info(f"Loading image {args.input}")
input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
input_img = input_img.astype(np.float32)
# input_img = cv2.resize(input_img, (512, 512))
ort_session = onnxruntime.InferenceSession(args.model)
logging.info("converting to blob")
input_blob = cv2.dnn.blobFromImage(
image=input_img,
scalefactor=1 / 255,
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
img = Image.open(args.input).convert("RGB")
logging.info(f"Preprocessing image {args.input}")
tf = A.Compose(
[
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
aug = tf(image=np.asarray(img))
img = aug["image"]
net.setInput(input_blob)
mask = net.forward()
mask = sigmoid(mask)
mask = mask > 0.5
mask = mask.astype(np.float32)
logging.info(f"Predicting image {args.input}")
img = img.unsqueeze(0)
logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask, "L")
mask.save(args.output)
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
img_out_y.save(args.output)