Merge branch 'onnx' into pl

Former-commit-id: 7f28adc38d86319181dea3e311976162ae20b6d2 [formerly 2b8801777106b1a17a9cef2c3a7d65e6ad780f3f]
Former-commit-id: 6d658bfd379005fcfdb75f97509be8dc920a8ef5
This commit is contained in:
Laurent Fainsin 2022-07-08 11:07:26 +02:00
commit 70b19b3b94
5 changed files with 1932 additions and 1534 deletions

6
.vscode/launch.json vendored
View file

@ -12,9 +12,11 @@
"console": "integratedTerminal",
"args": [
"--input",
"images/SM.png",
"images/test.png",
"--output",
"output.png",
"output_onnx.png",
"--model",
"good.onnx",
],
"justMyCode": true
}

1127
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -8,14 +8,16 @@ version = "0.1.0"
albumentations = "^1.2.0"
matplotlib = "^3.5.2"
numpy = "^1.23.0"
onnx = "^1.12.0"
onnxruntime = "^1.11.1"
python = ">=3.8,<3.11"
pytorch-lightning = "^1.6.4"
rich = "^12.4.4"
scipy = "^1.8.1"
torch = "^1.11.0"
torchvision = "^0.12.0"
tqdm = "^4.64.0"
wandb = "^0.12.19"
rich = "^12.4.4"
[tool.poetry.dev-dependencies]
black = "^22.3.0"

View file

@ -3,12 +3,12 @@ import logging
import albumentations as A
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from unet import UNet
def get_args():
parser = argparse.ArgumentParser(
@ -38,23 +38,23 @@ def get_args():
return parser.parse_args()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
if __name__ == "__main__":
args = get_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
net = UNet(n_channels=3, n_classes=1)
onnx_model = onnx.load(args.model)
onnx.checker.check_model(onnx_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
ort_session = onnxruntime.InferenceSession(args.model)
logging.info("Transfering model to device")
net.to(device=device)
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
logging.info(f"Loading model {args.model}")
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info(f"Loading image {args.input}")
img = Image.open(args.input).convert("RGB")
logging.info(f"Preprocessing image {args.input}")
@ -68,17 +68,14 @@ if __name__ == "__main__":
img = aug["image"]
logging.info(f"Predicting image {args.input}")
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
img = img.unsqueeze(0)
net.eval()
with torch.inference_mode():
mask = net(img)
mask = torch.sigmoid(mask)[0]
mask = mask.cpu()
mask = mask.squeeze()
mask = mask > 0.5
mask = np.asarray(mask)
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
ort_outs = ort_session.run(None, ort_inputs)
logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask)
mask.save(args.output)
img_out_y = ort_outs[0]
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
img_out_y.save(args.output)

View file

@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
self,
nb,
image_dir,
scale_range=(0.1, 0.2),
scale_range=(0.05, 0.25),
always_apply=True,
p=1.0,
):