From 36b044c719effa98e6cacd325193d1fd53a5fd52 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 5 Jul 2022 12:17:32 +0200 Subject: [PATCH] feat: working prediction but only for 512x512 Former-commit-id: c9d88ad18de91409fc1be1f1abe59d6e75ff2235 [formerly 8bf12bad1c3e8424aa26c7bd9a441facc670b059] Former-commit-id: 820e327f8a79bad36d1f15944012e77ba1ecd560 --- .vscode/launch.json | 2 +- poetry.lock | 83 ++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 ++ src/predict.py | 48 +++++++++++++++----------- 4 files changed, 114 insertions(+), 21 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 4992c79..a0ae3f2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "console": "integratedTerminal", "args": [ "--input", - "images/SM.png", + "images/test.png", "--output", "output_onnx.png", "--model", diff --git a/poetry.lock b/poetry.lock index ec4dac1..dea7656 100644 --- a/poetry.lock +++ b/poetry.lock @@ -240,6 +240,14 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "flatbuffers" +version = "2.0" +description = "The FlatBuffers serialization format for Python" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "fonttools" version = "4.33.3" @@ -673,6 +681,35 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "onnx" +version = "1.12.0" +description = "Open Neural Network Exchange" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +numpy = ">=1.16.6" +protobuf = ">=3.12.2,<=3.20.1" +typing-extensions = ">=3.6.2.1" + +[package.extras] +lint = ["clang-format (==13.0.0)", "flake8", "mypy (==0.782)", "types-protobuf (==3.18.4)"] + +[[package]] +name = "onnxruntime" +version = "1.11.1" +description = "ONNX Runtime is a runtime accelerator for Machine Learning models" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +flatbuffers = "*" +numpy = ">=1.21.6" +protobuf = "*" + [[package]] name = "opencv-python-headless" version = "4.6.0.66" @@ -1446,7 +1483,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.11" -content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5" +content-hash = "c1855a97cbe537d31526a76a49fde822e98acc89713f5f902639327b688c079a" [metadata.files] absl-py = [ @@ -1692,6 +1729,10 @@ executing = [ {file = "executing-0.8.3-py2.py3-none-any.whl", hash = "sha256:d1eef132db1b83649a3905ca6dd8897f71ac6f8cac79a7e58a1a09cf137546c9"}, {file = "executing-0.8.3.tar.gz", hash = "sha256:c6554e21c6b060590a6d3be4b82fb78f8f0194d809de5ea7df1c093763311501"}, ] +flatbuffers = [ + {file = "flatbuffers-2.0-py2.py3-none-any.whl", hash = "sha256:3751954f0604580d3219ae49a85fafec9d85eec599c0b96226e1bc0b48e57474"}, + {file = "flatbuffers-2.0.tar.gz", hash = "sha256:12158ab0272375eab8db2d663ae97370c33f152b27801fa6024e1d6105fd4dd2"}, +] fonttools = [ {file = "fonttools-4.33.3-py3-none-any.whl", hash = "sha256:f829c579a8678fa939a1d9e9894d01941db869de44390adb49ce67055a06cc2a"}, {file = "fonttools-4.33.3.zip", hash = "sha256:c0fdcfa8ceebd7c1b2021240bd46ef77aa8e7408cf10434be55df52384865f8e"}, @@ -2056,6 +2097,46 @@ oauthlib = [ {file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"}, {file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"}, ] +onnx = [ + {file = "onnx-1.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bdbd2578424c70836f4d0f9dda16c21868ddb07cc8192f9e8a176908b43d694b"}, + {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213e73610173f6b2e99f99a4b0636f80b379c417312079d603806e48ada4ca8b"}, + {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fd2f4e23078df197bb76a59b9cd8f5a43a6ad2edc035edb3ecfb9042093e05a"}, + {file = "onnx-1.12.0-cp310-cp310-win32.whl", hash = "sha256:23781594bb8b7ee985de1005b3c601648d5b0568a81e01365c48f91d1f5648e4"}, + {file = "onnx-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:81a3555fd67be2518bf86096299b48fb9154652596219890abfe90bd43a9ec13"}, + {file = "onnx-1.12.0-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:5578b93dc6c918cec4dee7fb7d9dd3b09d338301ee64ca8b4f28bc217ed42dca"}, + {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c11162ffc487167da140f1112f49c4f82d815824f06e58bc3095407699f05863"}, + {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341c7016e23273e9ffa9b6e301eee95b8c37d0f04df7cedbdb169d2c39524c96"}, + {file = "onnx-1.12.0-cp37-cp37m-win32.whl", hash = "sha256:3c6e6bcffc3f5c1e148df3837dc667fa4c51999788c1b76b0b8fbba607e02da8"}, + {file = "onnx-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8a7aa61aea339bd28f310f4af4f52ce6c4b876386228760b16308efd58f95059"}, + {file = "onnx-1.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:56ceb7e094c43882b723cfaa107d85ad673cfdf91faeb28d7dcadacca4f43a07"}, + {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3629e8258db15d4e2c9b7f1be91a3186719dd94661c218c6f5fde3cc7de3d4d"}, + {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d9a7db54e75529160337232282a4816cc50667dc7dc34be178fd6f6b79d4705"}, + {file = "onnx-1.12.0-cp38-cp38-win32.whl", hash = "sha256:fea5156a03398fe0e23248042d8651c1eaac5f6637d4dd683b4c1f1320b9f7b4"}, + {file = "onnx-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66d2996e65f490a57b3ae952e4e9189b53cc9fe3f75e601d50d4db2dc1b1cd9"}, + {file = "onnx-1.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c39a7a0352c856f1df30dccf527eb6cb4909052e5eaf6fa2772a637324c526aa"}, + {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab13feb4d94342aae6d357d480f2e47d41b9f4e584367542b21ca6defda9e0a"}, + {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9b3ea02c30efc1d2662337e280266aca491a8e86be0d8a657f874b7cccd1e"}, + {file = "onnx-1.12.0-cp39-cp39-win32.whl", hash = "sha256:f8800f28c746ab06e51ef8449fd1215621f4ddba91be3ffc264658937d38a2af"}, + {file = "onnx-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:af90427ca04c6b7b8107c2021e1273227a3ef1a7a01f3073039cae7855a59833"}, + {file = "onnx-1.12.0.tar.gz", hash = "sha256:13b3e77d27523b9dbf4f30dfc9c959455859d5e34e921c44f712d69b8369eff9"}, +] +onnxruntime = [ + {file = "onnxruntime-1.11.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:88b94a900754ef189c2b06f2046f2de8008753e0e8a3e562b2fb03298026b4b4"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:958974be7808b46815533c74e8849a2d73e73d656df8369a114ce3359f77760b"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f83e7d52932b68b08cfba4920816efc7c3177036a90116137b11888e1f2490"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-win32.whl", hash = "sha256:3106bfcd1532afcaa26fc47931f7a8770dc710263647e8fbb5f75fa5a8fc70f9"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-win_amd64.whl", hash = "sha256:80775f4f64850b6774dbaa955888a89dc719cf654f1995ed5418e78c0139b5f4"}, + {file = "onnxruntime-1.11.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:b24a3cd1e6d7fe7c4c5be2996ba02ebf8beed6347b2fd3ac869d1c685a2e0264"}, + {file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73c4df0a446fe49d59629746d2163fa39b732b6afb3b5d00f8c9ec91a040e5c4"}, + {file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149bf850c4e320e33894cab3e350a945ab17690cf54ffa00ef965273112ef614"}, + {file = "onnxruntime-1.11.1-cp38-cp38-win32.whl", hash = "sha256:9e202d7323b5728cdc3c0ee3bbc35f10cd56c7120c9626887c4ebe5d8503b488"}, + {file = "onnxruntime-1.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:00632fc2ee3cf86349f5b00f5385a62fe5720ef14b471c919cc2c94faeb446d0"}, + {file = "onnxruntime-1.11.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:792985ddf3d3c46efa24bcfef970e7ccd4421d46173a96ca3974dab709598591"}, + {file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ce95870b0bc7cbef5383b3c3062c6d9784af71f266192bc887928d7b927a46"}, + {file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53121b024f68d6b16bb93bc3fb73ba05b6f55647d12054a8efae7f48ed761add"}, + {file = "onnxruntime-1.11.1-cp39-cp39-win32.whl", hash = "sha256:b90124277454c50c5b2073bb9e1368b2a5672f30c2c8f3fe01393967dcd6dce2"}, + {file = "onnxruntime-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:b1ffefc961fc607e5929fef92f3bc8bc48bd3a074b2a6448887be23eb313f75a"}, +] opencv-python-headless = [ {file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"}, {file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"}, diff --git a/pyproject.toml b/pyproject.toml index 426fef6..32eaa53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,8 @@ torch = "^1.11.0" torchvision = "^0.12.0" tqdm = "^4.64.0" wandb = "^0.12.19" +onnx = "^1.12.0" +onnxruntime = "^1.11.1" [tool.poetry.dev-dependencies] black = "^22.3.0" diff --git a/src/predict.py b/src/predict.py index 0df1bab..f1be604 100755 --- a/src/predict.py +++ b/src/predict.py @@ -2,8 +2,9 @@ import argparse import logging import albumentations as A -import cv2 import numpy as np +import onnx +import onnxruntime import torch from albumentations.pytorch import ToTensorV2 from PIL import Image @@ -46,26 +47,35 @@ if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - net = cv2.dnn.readNetFromONNX(args.model) - logging.info("onnx model loaded") + onnx_model = onnx.load(args.model) + onnx.checker.check_model(onnx_model) - logging.info(f"Loading image {args.input}") - input_img = cv2.imread(args.input, cv2.IMREAD_COLOR) - input_img = input_img.astype(np.float32) - # input_img = cv2.resize(input_img, (512, 512)) + ort_session = onnxruntime.InferenceSession(args.model) - logging.info("converting to blob") - input_blob = cv2.dnn.blobFromImage( - image=input_img, - scalefactor=1 / 255, + def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + + img = Image.open(args.input).convert("RGB") + + logging.info(f"Preprocessing image {args.input}") + tf = A.Compose( + [ + A.ToFloat(max_value=255), + ToTensorV2(), + ], ) + aug = tf(image=np.asarray(img)) + img = aug["image"] - net.setInput(input_blob) - mask = net.forward() - mask = sigmoid(mask) - mask = mask > 0.5 - mask = mask.astype(np.float32) + logging.info(f"Predicting image {args.input}") + img = img.unsqueeze(0) - logging.info(f"Saving prediction to {args.output}") - mask = Image.fromarray(mask, "L") - mask.save(args.output) + # compute ONNX Runtime output prediction + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)} + ort_outs = ort_session.run(None, ort_inputs) + + img_out_y = ort_outs[0] + + img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L") + + img_out_y.save(args.output)