From 0fb1d4fb7a52fd95d5f68d1eccf5f8324f4fe08c Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Tue, 5 Jul 2022 12:06:12 +0200
Subject: [PATCH 1/2] feat(WIP): broken onnx prediction

Former-commit-id: cde7623ec486cf79a710949085aadd92d8a33a3e [formerly db0f1d0b9ea536c741f23a3b683e19a9335bcd35]
Former-commit-id: 7332ccb0f74c58c3a284a4568fb8f80a6d416cf4
---
 .vscode/launch.json |  6 ++--
 src/predict.py      | 53 +++++++++++----------
 src/train.py        | 40 ++++++++++++---------
 src/utils/dice.py   | 88 ++++++++------------------------------------
 src/utils/paste.py  |  2 +-
 5 files changed, 64 insertions(+), 125 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 178c089..4992c79 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -14,9 +14,11 @@
                 "--input",
                 "images/SM.png",
                 "--output",
-                "output.png",
+                "output_onnx.png",
+                "--model",
+                "good.onnx",
             ],
             "justMyCode": true
         }
     ]
-}
\ No newline at end of file
+}
diff --git a/src/predict.py b/src/predict.py
index a69f7bc..0df1bab 100755
--- a/src/predict.py
+++ b/src/predict.py
@@ -2,13 +2,12 @@ import argparse
 import logging
 
 import albumentations as A
+import cv2
 import numpy as np
 import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image
 
-from unet import UNet
-
 
 def get_args():
     parser = argparse.ArgumentParser(
@@ -38,47 +37,35 @@ def get_args():
     return parser.parse_args()
 
 
+def sigmoid(x):
+    return 1 / (1 + np.exp(-x))
+
+
 if __name__ == "__main__":
     args = get_args()
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
-    net = UNet(n_channels=3, n_classes=1)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logging.info(f"Using device {device}")
-
-    logging.info("Transfering model to device")
-    net.to(device=device)
-
-    logging.info(f"Loading model {args.model}")
-    net.load_state_dict(torch.load(args.model, map_location=device))
+    net = cv2.dnn.readNetFromONNX(args.model)
+    logging.info("onnx model loaded")
 
     logging.info(f"Loading image {args.input}")
-    img = Image.open(args.input).convert("RGB")
+    input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
+    input_img = input_img.astype(np.float32)
+    # input_img = cv2.resize(input_img, (512, 512))
 
-    logging.info(f"Preprocessing image {args.input}")
-    tf = A.Compose(
-        [
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
+    logging.info("converting to blob")
+    input_blob = cv2.dnn.blobFromImage(
+        image=input_img,
+        scalefactor=1 / 255,
     )
-    aug = tf(image=np.asarray(img))
-    img = aug["image"]
 
-    logging.info(f"Predicting image {args.input}")
-    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
-
-    net.eval()
-    with torch.inference_mode():
-        mask = net(img)
-        mask = torch.sigmoid(mask)[0]
-        mask = mask.cpu()
-        mask = mask.squeeze()
-        mask = mask > 0.5
-        mask = np.asarray(mask)
+    net.setInput(input_blob)
+    mask = net.forward()
+    mask = sigmoid(mask)
+    mask = mask > 0.5
+    mask = mask.astype(np.float32)
 
     logging.info(f"Saving prediction to {args.output}")
-    mask = Image.fromarray(mask)
+    mask = Image.fromarray(mask, "L")
     mask.save(args.output)
diff --git a/src/train.py b/src/train.py
index 9ddac8e..2963c3b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -5,12 +5,13 @@ import torch
 import yaml
 from albumentations.pytorch import ToTensorV2
 from torch.utils.data import DataLoader
+from torchmetrics import Dice
 from tqdm import tqdm
 
 import wandb
 from src.utils.dataset import SphereDataset
 from unet import UNet
-from utils.dice import dice_coeff
+from utils.dice import DiceLoss
 from utils.paste import RandomPaste
 
 class_labels = {
@@ -37,8 +38,8 @@ if __name__ == "__main__":
         PIN_MEMORY=True,
         BENCHMARK=True,
         DEVICE="cuda",
-        WORKERS=8,
-        EPOCHS=5,
+        WORKERS=7,
+        EPOCHS=1001,
         BATCH_SIZE=16,
         LEARNING_RATE=1e-4,
         WEIGHT_DECAY=1e-8,
@@ -92,9 +93,13 @@ if __name__ == "__main__":
     ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
 
     # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
-    ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
-    ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
+    # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+    # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
+    # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
+
+    ds_train = torch.utils.data.Subset(ds_train, [0])
+    ds_valid = torch.utils.data.Subset(ds_valid, [0])
+    ds_test = torch.utils.data.Subset(ds_test, [0])
 
     # 3. Create data loaders
     train_loader = DataLoader(
@@ -131,18 +136,19 @@
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()
+    dice_loss = DiceLoss()
 
     # save model.onxx
     dummy_input = torch.randn(
         1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
     ).to(device)
-    torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
+    torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
     artifact = wandb.Artifact("onnx", type="model")
-    artifact.add_file("checkpoints/model-0.onnx")
+    artifact.add_file("checkpoints/model.onnx")
     wandb.run.log_artifact(artifact)
 
     # log gradients and weights four time per epoch
-    wandb.watch(net, criterion, log_freq=100)
+    wandb.watch(net, log_freq=100)
 
     # print the config
     logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
@@ -176,6 +182,8 @@
                 pred_masks = net(images)
                 train_loss = criterion(pred_masks, true_masks)
 
+                # compute loss
+
                 # backward
                 optimizer.zero_grad(set_to_none=True)
                 grad_scaler.scale(train_loss).backward()
@@ -185,7 +193,7 @@
                 # compute metrics
                 pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
                 accuracy = (true_masks == pred_masks_bin).float().mean()
-                dice = dice_coeff(pred_masks_bin, true_masks)
+                dice = dice_loss.coeff(pred_masks, true_masks)
                 mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
 
                 # update tqdm progress bar
@@ -197,13 +205,13 @@
                    {
                        "epoch": epoch - 1 + step / len(train_loader),
                        "train/accuracy": accuracy,
-                       "train/bce": train_loss,
+                       "train/loss": train_loss,
                        "train/dice": dice,
                        "train/mae": mae,
                    }
                )
 
-                if step and (step % 250 == 0 or step == len(train_loader)):
+                if step and (step % 100 == 0 or step == len(train_loader)):
                    # Evaluation round
                    net.eval()
                    accuracy = 0
@@ -223,10 +231,10 @@
 
                            # compute metrics
                            val_loss += criterion(masks_pred, masks_true)
+                           dice += dice_loss.coeff(masks_pred, masks_true)
                            masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
                            mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
                            accuracy += (masks_true == masks_pred_bin).float().mean()
-                           dice += dice_coeff(masks_pred_bin, masks_true)
 
                            # update progress bar
                            pbar2.update(images.shape[0])
@@ -267,7 +275,7 @@
"val/predictions": table, "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], "val/accuracy": accuracy, - "val/bce": val_loss, + "val/loss": val_loss, "val/dice": dice, "val/mae": mae, }, @@ -276,7 +284,7 @@ if __name__ == "__main__": # update hyperparameters net.train() - scheduler.step(dice) + scheduler.step(train_loss) # export model to onnx format when validation ends dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) @@ -304,10 +312,10 @@ if __name__ == "__main__": # compute metrics val_loss += criterion(masks_pred, masks_true) + dice += dice_loss.coeff(pred_masks, true_masks) masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy += (masks_true == masks_pred_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) # update progress bar pbar3.update(images.shape[0]) @@ -347,7 +355,7 @@ if __name__ == "__main__": { "test/predictions": table, "test/accuracy": accuracy, - "test/bce": val_loss, + "test/loss": val_loss, "test/dice": dice, "test/mae": mae, }, diff --git a/src/utils/dice.py b/src/utils/dice.py index a29f794..c4b201c 100644 --- a/src/utils/dice.py +++ b/src/utils/dice.py @@ -1,80 +1,22 @@ import torch -from torch import Tensor +import torch.nn as nn -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all batches, or for a single mask. +class DiceLoss(nn.Module): + def __init__(self, weight=None, size_average=True): + super(DiceLoss, self).__init__() - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. + @staticmethod + def coeff(inputs, targets, smooth=1): + # comment out if your model contains a sigmoid or equivalent activation layer + inputs = torch.sigmoid(inputs) - Raises: - ValueError: _description_ + # flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) - Returns: - float: _description_ - """ - assert input.size() == target.size() + intersection = (inputs * targets).sum() + return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) - if input.dim() == 2 and reduce_batch_first: - raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})") - - if input.dim() == 2 or reduce_batch_first: - inter = torch.dot(input.reshape(-1), target.reshape(-1)) - sets_sum = torch.sum(input) + torch.sum(target) - - if sets_sum.item() == 0: - sets_sum = 2 * inter - - return (2 * inter + epsilon) / (sets_sum + epsilon) - else: - # compute and average metric for each batch element - dice = 0 - - for i in range(input.shape[0]): - dice += dice_coeff(input[i, ...], target[i, ...]) - - return dice / input.shape[0] - - -def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all classes. - - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. 
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    dice = 0
-
-    for channel in range(input.shape[1]):
-        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
-
-    return dice / input.shape[1]
-
-
-def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
-    """Dice loss (objective to minimize) between 0 and 1.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        multiclass (bool, optional): _description_. Defaults to False.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    fn = multiclass_dice_coeff if multiclass else dice_coeff
-
-    return 1 - fn(input, target, reduce_batch_first=True)
+    def forward(self, inputs, targets, smooth=1):
+        return 1 - self.coeff(inputs, targets, smooth)
diff --git a/src/utils/paste.py b/src/utils/paste.py
index 486a8ec..90ef0a8 100644
--- a/src/utils/paste.py
+++ b/src/utils/paste.py
@@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
         nb,
         path_paste_img_dir,
         path_paste_mask_dir,
-        scale_range=(0.1, 0.2),
+        scale_range=(0.05, 0.25),
         always_apply=True,
         p=1.0,
     ):

From 36b044c719effa98e6cacd325193d1fd53a5fd52 Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Tue, 5 Jul 2022 12:17:32 +0200
Subject: [PATCH 2/2] feat: working prediction but only for 512x512

Former-commit-id: c9d88ad18de91409fc1be1f1abe59d6e75ff2235 [formerly 8bf12bad1c3e8424aa26c7bd9a441facc670b059]
Former-commit-id: 820e327f8a79bad36d1f15944012e77ba1ecd560
---
 .vscode/launch.json |  2 +-
 poetry.lock         | 83 ++++++++++++++++++++++++++++++++++++++++++++-
 pyproject.toml      |  2 ++
 src/predict.py      | 48 +++++++++++++++----------
 4 files changed, 114 insertions(+), 21 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 4992c79..a0ae3f2 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -12,7 +12,7 @@
             "console": "integratedTerminal",
             "args": [
                 "--input",
-                "images/SM.png",
+                "images/test.png",
                 "--output",
                 "output_onnx.png",
                 "--model",
diff --git a/poetry.lock b/poetry.lock
index ec4dac1..dea7656 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -240,6 +240,14 @@ category = "dev"
 optional = false
 python-versions = "*"
 
+[[package]]
+name = "flatbuffers"
+version = "2.0"
+description = "The FlatBuffers serialization format for Python"
+category = "main"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "fonttools"
 version = "4.33.3"
@@ -673,6 +681,35 @@ rsa = ["cryptography (>=3.0.0)"]
 signals = ["blinker (>=1.4.0)"]
 signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
 
+[[package]]
+name = "onnx"
+version = "1.12.0"
+description = "Open Neural Network Exchange"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+numpy = ">=1.16.6"
+protobuf = ">=3.12.2,<=3.20.1"
+typing-extensions = ">=3.6.2.1"
+
+[package.extras]
+lint = ["clang-format (==13.0.0)", "flake8", "mypy (==0.782)", "types-protobuf (==3.18.4)"]
+
+[[package]]
+name = "onnxruntime"
+version = "1.11.1"
+description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+flatbuffers = "*"
+numpy = ">=1.21.6"
+protobuf = "*"
+
 [[package]]
 name = "opencv-python-headless"
 version = "4.6.0.66"
@@ -1446,7 +1483,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5"
"b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5" +content-hash = "c1855a97cbe537d31526a76a49fde822e98acc89713f5f902639327b688c079a" [metadata.files] absl-py = [ @@ -1692,6 +1729,10 @@ executing = [ {file = "executing-0.8.3-py2.py3-none-any.whl", hash = "sha256:d1eef132db1b83649a3905ca6dd8897f71ac6f8cac79a7e58a1a09cf137546c9"}, {file = "executing-0.8.3.tar.gz", hash = "sha256:c6554e21c6b060590a6d3be4b82fb78f8f0194d809de5ea7df1c093763311501"}, ] +flatbuffers = [ + {file = "flatbuffers-2.0-py2.py3-none-any.whl", hash = "sha256:3751954f0604580d3219ae49a85fafec9d85eec599c0b96226e1bc0b48e57474"}, + {file = "flatbuffers-2.0.tar.gz", hash = "sha256:12158ab0272375eab8db2d663ae97370c33f152b27801fa6024e1d6105fd4dd2"}, +] fonttools = [ {file = "fonttools-4.33.3-py3-none-any.whl", hash = "sha256:f829c579a8678fa939a1d9e9894d01941db869de44390adb49ce67055a06cc2a"}, {file = "fonttools-4.33.3.zip", hash = "sha256:c0fdcfa8ceebd7c1b2021240bd46ef77aa8e7408cf10434be55df52384865f8e"}, @@ -2056,6 +2097,46 @@ oauthlib = [ {file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"}, {file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"}, ] +onnx = [ + {file = "onnx-1.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bdbd2578424c70836f4d0f9dda16c21868ddb07cc8192f9e8a176908b43d694b"}, + {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213e73610173f6b2e99f99a4b0636f80b379c417312079d603806e48ada4ca8b"}, + {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fd2f4e23078df197bb76a59b9cd8f5a43a6ad2edc035edb3ecfb9042093e05a"}, + {file = "onnx-1.12.0-cp310-cp310-win32.whl", hash = "sha256:23781594bb8b7ee985de1005b3c601648d5b0568a81e01365c48f91d1f5648e4"}, + {file = "onnx-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:81a3555fd67be2518bf86096299b48fb9154652596219890abfe90bd43a9ec13"}, + {file = "onnx-1.12.0-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:5578b93dc6c918cec4dee7fb7d9dd3b09d338301ee64ca8b4f28bc217ed42dca"}, + {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c11162ffc487167da140f1112f49c4f82d815824f06e58bc3095407699f05863"}, + {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341c7016e23273e9ffa9b6e301eee95b8c37d0f04df7cedbdb169d2c39524c96"}, + {file = "onnx-1.12.0-cp37-cp37m-win32.whl", hash = "sha256:3c6e6bcffc3f5c1e148df3837dc667fa4c51999788c1b76b0b8fbba607e02da8"}, + {file = "onnx-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8a7aa61aea339bd28f310f4af4f52ce6c4b876386228760b16308efd58f95059"}, + {file = "onnx-1.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:56ceb7e094c43882b723cfaa107d85ad673cfdf91faeb28d7dcadacca4f43a07"}, + {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3629e8258db15d4e2c9b7f1be91a3186719dd94661c218c6f5fde3cc7de3d4d"}, + {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d9a7db54e75529160337232282a4816cc50667dc7dc34be178fd6f6b79d4705"}, + {file = "onnx-1.12.0-cp38-cp38-win32.whl", hash = "sha256:fea5156a03398fe0e23248042d8651c1eaac5f6637d4dd683b4c1f1320b9f7b4"}, + {file = "onnx-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66d2996e65f490a57b3ae952e4e9189b53cc9fe3f75e601d50d4db2dc1b1cd9"}, + {file = 
"onnx-1.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c39a7a0352c856f1df30dccf527eb6cb4909052e5eaf6fa2772a637324c526aa"}, + {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab13feb4d94342aae6d357d480f2e47d41b9f4e584367542b21ca6defda9e0a"}, + {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9b3ea02c30efc1d2662337e280266aca491a8e86be0d8a657f874b7cccd1e"}, + {file = "onnx-1.12.0-cp39-cp39-win32.whl", hash = "sha256:f8800f28c746ab06e51ef8449fd1215621f4ddba91be3ffc264658937d38a2af"}, + {file = "onnx-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:af90427ca04c6b7b8107c2021e1273227a3ef1a7a01f3073039cae7855a59833"}, + {file = "onnx-1.12.0.tar.gz", hash = "sha256:13b3e77d27523b9dbf4f30dfc9c959455859d5e34e921c44f712d69b8369eff9"}, +] +onnxruntime = [ + {file = "onnxruntime-1.11.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:88b94a900754ef189c2b06f2046f2de8008753e0e8a3e562b2fb03298026b4b4"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:958974be7808b46815533c74e8849a2d73e73d656df8369a114ce3359f77760b"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f83e7d52932b68b08cfba4920816efc7c3177036a90116137b11888e1f2490"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-win32.whl", hash = "sha256:3106bfcd1532afcaa26fc47931f7a8770dc710263647e8fbb5f75fa5a8fc70f9"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-win_amd64.whl", hash = "sha256:80775f4f64850b6774dbaa955888a89dc719cf654f1995ed5418e78c0139b5f4"}, + {file = "onnxruntime-1.11.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:b24a3cd1e6d7fe7c4c5be2996ba02ebf8beed6347b2fd3ac869d1c685a2e0264"}, + {file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73c4df0a446fe49d59629746d2163fa39b732b6afb3b5d00f8c9ec91a040e5c4"}, + {file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149bf850c4e320e33894cab3e350a945ab17690cf54ffa00ef965273112ef614"}, + {file = "onnxruntime-1.11.1-cp38-cp38-win32.whl", hash = "sha256:9e202d7323b5728cdc3c0ee3bbc35f10cd56c7120c9626887c4ebe5d8503b488"}, + {file = "onnxruntime-1.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:00632fc2ee3cf86349f5b00f5385a62fe5720ef14b471c919cc2c94faeb446d0"}, + {file = "onnxruntime-1.11.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:792985ddf3d3c46efa24bcfef970e7ccd4421d46173a96ca3974dab709598591"}, + {file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ce95870b0bc7cbef5383b3c3062c6d9784af71f266192bc887928d7b927a46"}, + {file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53121b024f68d6b16bb93bc3fb73ba05b6f55647d12054a8efae7f48ed761add"}, + {file = "onnxruntime-1.11.1-cp39-cp39-win32.whl", hash = "sha256:b90124277454c50c5b2073bb9e1368b2a5672f30c2c8f3fe01393967dcd6dce2"}, + {file = "onnxruntime-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:b1ffefc961fc607e5929fef92f3bc8bc48bd3a074b2a6448887be23eb313f75a"}, +] opencv-python-headless = [ {file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"}, {file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"}, diff --git a/pyproject.toml b/pyproject.toml index 
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,8 @@ torch = "^1.11.0"
 torchvision = "^0.12.0"
 tqdm = "^4.64.0"
 wandb = "^0.12.19"
+onnx = "^1.12.0"
+onnxruntime = "^1.11.1"
 
 [tool.poetry.dev-dependencies]
 black = "^22.3.0"
diff --git a/src/predict.py b/src/predict.py
index 0df1bab..f1be604 100755
--- a/src/predict.py
+++ b/src/predict.py
@@ -2,8 +2,9 @@ import argparse
 import logging
 
 import albumentations as A
-import cv2
 import numpy as np
+import onnx
+import onnxruntime
 import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image
@@ -46,26 +47,35 @@
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
-    net = cv2.dnn.readNetFromONNX(args.model)
-    logging.info("onnx model loaded")
+    onnx_model = onnx.load(args.model)
+    onnx.checker.check_model(onnx_model)
 
-    logging.info(f"Loading image {args.input}")
-    input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
-    input_img = input_img.astype(np.float32)
-    # input_img = cv2.resize(input_img, (512, 512))
+    ort_session = onnxruntime.InferenceSession(args.model)
 
-    logging.info("converting to blob")
-    input_blob = cv2.dnn.blobFromImage(
-        image=input_img,
-        scalefactor=1 / 255,
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    img = Image.open(args.input).convert("RGB")
+
+    logging.info(f"Preprocessing image {args.input}")
+    tf = A.Compose(
+        [
+            A.ToFloat(max_value=255),
+            ToTensorV2(),
+        ],
     )
+    aug = tf(image=np.asarray(img))
+    img = aug["image"]
 
-    net.setInput(input_blob)
-    mask = net.forward()
-    mask = sigmoid(mask)
-    mask = mask > 0.5
-    mask = mask.astype(np.float32)
+    logging.info(f"Predicting image {args.input}")
+    img = img.unsqueeze(0)
 
-    logging.info(f"Saving prediction to {args.output}")
-    mask = Image.fromarray(mask, "L")
-    mask.save(args.output)
+    # compute ONNX Runtime output prediction
+    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    img_out_y = ort_outs[0]
+
+    img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
+
+    img_out_y.save(args.output)