Merge branch 'onnx' into pl

Former-commit-id: 7f28adc38d86319181dea3e311976162ae20b6d2 [formerly 2b8801777106b1a17a9cef2c3a7d65e6ad780f3f]
Former-commit-id: 6d658bfd379005fcfdb75f97509be8dc920a8ef5
This commit is contained in:
Laurent Fainsin 2022-07-08 11:07:26 +02:00
commit 70b19b3b94
5 changed files with 1932 additions and 1534 deletions

8
.vscode/launch.json vendored
View file

@ -12,11 +12,13 @@
"console": "integratedTerminal", "console": "integratedTerminal",
"args": [ "args": [
"--input", "--input",
"images/SM.png", "images/test.png",
"--output", "--output",
"output.png", "output_onnx.png",
"--model",
"good.onnx",
], ],
"justMyCode": true "justMyCode": true
} }
] ]
} }

3409
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -8,14 +8,16 @@ version = "0.1.0"
albumentations = "^1.2.0" albumentations = "^1.2.0"
matplotlib = "^3.5.2" matplotlib = "^3.5.2"
numpy = "^1.23.0" numpy = "^1.23.0"
onnx = "^1.12.0"
onnxruntime = "^1.11.1"
python = ">=3.8,<3.11" python = ">=3.8,<3.11"
pytorch-lightning = "^1.6.4" pytorch-lightning = "^1.6.4"
rich = "^12.4.4"
scipy = "^1.8.1" scipy = "^1.8.1"
torch = "^1.11.0" torch = "^1.11.0"
torchvision = "^0.12.0" torchvision = "^0.12.0"
tqdm = "^4.64.0" tqdm = "^4.64.0"
wandb = "^0.12.19" wandb = "^0.12.19"
rich = "^12.4.4"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
black = "^22.3.0" black = "^22.3.0"

View file

@ -3,12 +3,12 @@ import logging
import albumentations as A import albumentations as A
import numpy as np import numpy as np
import onnx
import onnxruntime
import torch import torch
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from PIL import Image from PIL import Image
from unet import UNet
def get_args(): def get_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -38,23 +38,23 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
net = UNet(n_channels=3, n_classes=1) onnx_model = onnx.load(args.model)
onnx.checker.check_model(onnx_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ort_session = onnxruntime.InferenceSession(args.model)
logging.info(f"Using device {device}")
logging.info("Transfering model to device") def to_numpy(tensor):
net.to(device=device) return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
logging.info(f"Loading model {args.model}")
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info(f"Loading image {args.input}")
img = Image.open(args.input).convert("RGB") img = Image.open(args.input).convert("RGB")
logging.info(f"Preprocessing image {args.input}") logging.info(f"Preprocessing image {args.input}")
@ -68,17 +68,14 @@ if __name__ == "__main__":
img = aug["image"] img = aug["image"]
logging.info(f"Predicting image {args.input}") logging.info(f"Predicting image {args.input}")
img = img.unsqueeze(0).to(device=device, dtype=torch.float32) img = img.unsqueeze(0)
net.eval() # compute ONNX Runtime output prediction
with torch.inference_mode(): ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
mask = net(img) ort_outs = ort_session.run(None, ort_inputs)
mask = torch.sigmoid(mask)[0]
mask = mask.cpu()
mask = mask.squeeze()
mask = mask > 0.5
mask = np.asarray(mask)
logging.info(f"Saving prediction to {args.output}") img_out_y = ort_outs[0]
mask = Image.fromarray(mask)
mask.save(args.output) img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
img_out_y.save(args.output)

View file

@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
self, self,
nb, nb,
image_dir, image_dir,
scale_range=(0.1, 0.2), scale_range=(0.05, 0.25),
always_apply=True, always_apply=True,
p=1.0, p=1.0,
): ):