mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: working prediction but only for 512x512
Former-commit-id: c9d88ad18de91409fc1be1f1abe59d6e75ff2235 [formerly 8bf12bad1c3e8424aa26c7bd9a441facc670b059] Former-commit-id: 820e327f8a79bad36d1f15944012e77ba1ecd560
This commit is contained in:
parent
0fb1d4fb7a
commit
36b044c719
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
|
@ -12,7 +12,7 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--input",
|
||||
"images/SM.png",
|
||||
"images/test.png",
|
||||
"--output",
|
||||
"output_onnx.png",
|
||||
"--model",
|
||||
|
|
83
poetry.lock
generated
83
poetry.lock
generated
|
@ -240,6 +240,14 @@ category = "dev"
|
|||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "flatbuffers"
|
||||
version = "2.0"
|
||||
description = "The FlatBuffers serialization format for Python"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "fonttools"
|
||||
version = "4.33.3"
|
||||
|
@ -673,6 +681,35 @@ rsa = ["cryptography (>=3.0.0)"]
|
|||
signals = ["blinker (>=1.4.0)"]
|
||||
signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
|
||||
|
||||
[[package]]
|
||||
name = "onnx"
|
||||
version = "1.12.0"
|
||||
description = "Open Neural Network Exchange"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.16.6"
|
||||
protobuf = ">=3.12.2,<=3.20.1"
|
||||
typing-extensions = ">=3.6.2.1"
|
||||
|
||||
[package.extras]
|
||||
lint = ["clang-format (==13.0.0)", "flake8", "mypy (==0.782)", "types-protobuf (==3.18.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "onnxruntime"
|
||||
version = "1.11.1"
|
||||
description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[package.dependencies]
|
||||
flatbuffers = "*"
|
||||
numpy = ">=1.21.6"
|
||||
protobuf = "*"
|
||||
|
||||
[[package]]
|
||||
name = "opencv-python-headless"
|
||||
version = "4.6.0.66"
|
||||
|
@ -1446,7 +1483,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8,<3.11"
|
||||
content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5"
|
||||
content-hash = "c1855a97cbe537d31526a76a49fde822e98acc89713f5f902639327b688c079a"
|
||||
|
||||
[metadata.files]
|
||||
absl-py = [
|
||||
|
@ -1692,6 +1729,10 @@ executing = [
|
|||
{file = "executing-0.8.3-py2.py3-none-any.whl", hash = "sha256:d1eef132db1b83649a3905ca6dd8897f71ac6f8cac79a7e58a1a09cf137546c9"},
|
||||
{file = "executing-0.8.3.tar.gz", hash = "sha256:c6554e21c6b060590a6d3be4b82fb78f8f0194d809de5ea7df1c093763311501"},
|
||||
]
|
||||
flatbuffers = [
|
||||
{file = "flatbuffers-2.0-py2.py3-none-any.whl", hash = "sha256:3751954f0604580d3219ae49a85fafec9d85eec599c0b96226e1bc0b48e57474"},
|
||||
{file = "flatbuffers-2.0.tar.gz", hash = "sha256:12158ab0272375eab8db2d663ae97370c33f152b27801fa6024e1d6105fd4dd2"},
|
||||
]
|
||||
fonttools = [
|
||||
{file = "fonttools-4.33.3-py3-none-any.whl", hash = "sha256:f829c579a8678fa939a1d9e9894d01941db869de44390adb49ce67055a06cc2a"},
|
||||
{file = "fonttools-4.33.3.zip", hash = "sha256:c0fdcfa8ceebd7c1b2021240bd46ef77aa8e7408cf10434be55df52384865f8e"},
|
||||
|
@ -2056,6 +2097,46 @@ oauthlib = [
|
|||
{file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"},
|
||||
{file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"},
|
||||
]
|
||||
onnx = [
|
||||
{file = "onnx-1.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bdbd2578424c70836f4d0f9dda16c21868ddb07cc8192f9e8a176908b43d694b"},
|
||||
{file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213e73610173f6b2e99f99a4b0636f80b379c417312079d603806e48ada4ca8b"},
|
||||
{file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fd2f4e23078df197bb76a59b9cd8f5a43a6ad2edc035edb3ecfb9042093e05a"},
|
||||
{file = "onnx-1.12.0-cp310-cp310-win32.whl", hash = "sha256:23781594bb8b7ee985de1005b3c601648d5b0568a81e01365c48f91d1f5648e4"},
|
||||
{file = "onnx-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:81a3555fd67be2518bf86096299b48fb9154652596219890abfe90bd43a9ec13"},
|
||||
{file = "onnx-1.12.0-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:5578b93dc6c918cec4dee7fb7d9dd3b09d338301ee64ca8b4f28bc217ed42dca"},
|
||||
{file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c11162ffc487167da140f1112f49c4f82d815824f06e58bc3095407699f05863"},
|
||||
{file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341c7016e23273e9ffa9b6e301eee95b8c37d0f04df7cedbdb169d2c39524c96"},
|
||||
{file = "onnx-1.12.0-cp37-cp37m-win32.whl", hash = "sha256:3c6e6bcffc3f5c1e148df3837dc667fa4c51999788c1b76b0b8fbba607e02da8"},
|
||||
{file = "onnx-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8a7aa61aea339bd28f310f4af4f52ce6c4b876386228760b16308efd58f95059"},
|
||||
{file = "onnx-1.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:56ceb7e094c43882b723cfaa107d85ad673cfdf91faeb28d7dcadacca4f43a07"},
|
||||
{file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3629e8258db15d4e2c9b7f1be91a3186719dd94661c218c6f5fde3cc7de3d4d"},
|
||||
{file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d9a7db54e75529160337232282a4816cc50667dc7dc34be178fd6f6b79d4705"},
|
||||
{file = "onnx-1.12.0-cp38-cp38-win32.whl", hash = "sha256:fea5156a03398fe0e23248042d8651c1eaac5f6637d4dd683b4c1f1320b9f7b4"},
|
||||
{file = "onnx-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66d2996e65f490a57b3ae952e4e9189b53cc9fe3f75e601d50d4db2dc1b1cd9"},
|
||||
{file = "onnx-1.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c39a7a0352c856f1df30dccf527eb6cb4909052e5eaf6fa2772a637324c526aa"},
|
||||
{file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab13feb4d94342aae6d357d480f2e47d41b9f4e584367542b21ca6defda9e0a"},
|
||||
{file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9b3ea02c30efc1d2662337e280266aca491a8e86be0d8a657f874b7cccd1e"},
|
||||
{file = "onnx-1.12.0-cp39-cp39-win32.whl", hash = "sha256:f8800f28c746ab06e51ef8449fd1215621f4ddba91be3ffc264658937d38a2af"},
|
||||
{file = "onnx-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:af90427ca04c6b7b8107c2021e1273227a3ef1a7a01f3073039cae7855a59833"},
|
||||
{file = "onnx-1.12.0.tar.gz", hash = "sha256:13b3e77d27523b9dbf4f30dfc9c959455859d5e34e921c44f712d69b8369eff9"},
|
||||
]
|
||||
onnxruntime = [
|
||||
{file = "onnxruntime-1.11.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:88b94a900754ef189c2b06f2046f2de8008753e0e8a3e562b2fb03298026b4b4"},
|
||||
{file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:958974be7808b46815533c74e8849a2d73e73d656df8369a114ce3359f77760b"},
|
||||
{file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f83e7d52932b68b08cfba4920816efc7c3177036a90116137b11888e1f2490"},
|
||||
{file = "onnxruntime-1.11.1-cp37-cp37m-win32.whl", hash = "sha256:3106bfcd1532afcaa26fc47931f7a8770dc710263647e8fbb5f75fa5a8fc70f9"},
|
||||
{file = "onnxruntime-1.11.1-cp37-cp37m-win_amd64.whl", hash = "sha256:80775f4f64850b6774dbaa955888a89dc719cf654f1995ed5418e78c0139b5f4"},
|
||||
{file = "onnxruntime-1.11.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:b24a3cd1e6d7fe7c4c5be2996ba02ebf8beed6347b2fd3ac869d1c685a2e0264"},
|
||||
{file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73c4df0a446fe49d59629746d2163fa39b732b6afb3b5d00f8c9ec91a040e5c4"},
|
||||
{file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149bf850c4e320e33894cab3e350a945ab17690cf54ffa00ef965273112ef614"},
|
||||
{file = "onnxruntime-1.11.1-cp38-cp38-win32.whl", hash = "sha256:9e202d7323b5728cdc3c0ee3bbc35f10cd56c7120c9626887c4ebe5d8503b488"},
|
||||
{file = "onnxruntime-1.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:00632fc2ee3cf86349f5b00f5385a62fe5720ef14b471c919cc2c94faeb446d0"},
|
||||
{file = "onnxruntime-1.11.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:792985ddf3d3c46efa24bcfef970e7ccd4421d46173a96ca3974dab709598591"},
|
||||
{file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ce95870b0bc7cbef5383b3c3062c6d9784af71f266192bc887928d7b927a46"},
|
||||
{file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53121b024f68d6b16bb93bc3fb73ba05b6f55647d12054a8efae7f48ed761add"},
|
||||
{file = "onnxruntime-1.11.1-cp39-cp39-win32.whl", hash = "sha256:b90124277454c50c5b2073bb9e1368b2a5672f30c2c8f3fe01393967dcd6dce2"},
|
||||
{file = "onnxruntime-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:b1ffefc961fc607e5929fef92f3bc8bc48bd3a074b2a6448887be23eb313f75a"},
|
||||
]
|
||||
opencv-python-headless = [
|
||||
{file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"},
|
||||
{file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"},
|
||||
|
|
|
@ -15,6 +15,8 @@ torch = "^1.11.0"
|
|||
torchvision = "^0.12.0"
|
||||
tqdm = "^4.64.0"
|
||||
wandb = "^0.12.19"
|
||||
onnx = "^1.12.0"
|
||||
onnxruntime = "^1.11.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^22.3.0"
|
||||
|
|
|
@ -2,8 +2,9 @@ import argparse
|
|||
import logging
|
||||
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime
|
||||
import torch
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from PIL import Image
|
||||
|
@ -46,26 +47,35 @@ if __name__ == "__main__":
|
|||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
net = cv2.dnn.readNetFromONNX(args.model)
|
||||
logging.info("onnx model loaded")
|
||||
onnx_model = onnx.load(args.model)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
logging.info(f"Loading image {args.input}")
|
||||
input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
|
||||
input_img = input_img.astype(np.float32)
|
||||
# input_img = cv2.resize(input_img, (512, 512))
|
||||
ort_session = onnxruntime.InferenceSession(args.model)
|
||||
|
||||
logging.info("converting to blob")
|
||||
input_blob = cv2.dnn.blobFromImage(
|
||||
image=input_img,
|
||||
scalefactor=1 / 255,
|
||||
def to_numpy(tensor):
|
||||
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
||||
|
||||
img = Image.open(args.input).convert("RGB")
|
||||
|
||||
logging.info(f"Preprocessing image {args.input}")
|
||||
tf = A.Compose(
|
||||
[
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
)
|
||||
aug = tf(image=np.asarray(img))
|
||||
img = aug["image"]
|
||||
|
||||
net.setInput(input_blob)
|
||||
mask = net.forward()
|
||||
mask = sigmoid(mask)
|
||||
mask = mask > 0.5
|
||||
mask = mask.astype(np.float32)
|
||||
logging.info(f"Predicting image {args.input}")
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
logging.info(f"Saving prediction to {args.output}")
|
||||
mask = Image.fromarray(mask, "L")
|
||||
mask.save(args.output)
|
||||
# compute ONNX Runtime output prediction
|
||||
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
|
||||
ort_outs = ort_session.run(None, ort_inputs)
|
||||
|
||||
img_out_y = ort_outs[0]
|
||||
|
||||
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
|
||||
|
||||
img_out_y.save(args.output)
|
||||
|
|
Loading…
Reference in a new issue