mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 23:12:05 +00:00
Merge branch 'onnx' into pl
Former-commit-id: 7f28adc38d86319181dea3e311976162ae20b6d2 [formerly 2b8801777106b1a17a9cef2c3a7d65e6ad780f3f] Former-commit-id: 6d658bfd379005fcfdb75f97509be8dc920a8ef5
This commit is contained in:
commit
70b19b3b94
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
|
@ -12,9 +12,11 @@
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": [
|
"args": [
|
||||||
"--input",
|
"--input",
|
||||||
"images/SM.png",
|
"images/test.png",
|
||||||
"--output",
|
"--output",
|
||||||
"output.png",
|
"output_onnx.png",
|
||||||
|
"--model",
|
||||||
|
"good.onnx",
|
||||||
],
|
],
|
||||||
"justMyCode": true
|
"justMyCode": true
|
||||||
}
|
}
|
||||||
|
|
1127
poetry.lock
generated
1127
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -8,14 +8,16 @@ version = "0.1.0"
|
||||||
albumentations = "^1.2.0"
|
albumentations = "^1.2.0"
|
||||||
matplotlib = "^3.5.2"
|
matplotlib = "^3.5.2"
|
||||||
numpy = "^1.23.0"
|
numpy = "^1.23.0"
|
||||||
|
onnx = "^1.12.0"
|
||||||
|
onnxruntime = "^1.11.1"
|
||||||
python = ">=3.8,<3.11"
|
python = ">=3.8,<3.11"
|
||||||
pytorch-lightning = "^1.6.4"
|
pytorch-lightning = "^1.6.4"
|
||||||
|
rich = "^12.4.4"
|
||||||
scipy = "^1.8.1"
|
scipy = "^1.8.1"
|
||||||
torch = "^1.11.0"
|
torch = "^1.11.0"
|
||||||
torchvision = "^0.12.0"
|
torchvision = "^0.12.0"
|
||||||
tqdm = "^4.64.0"
|
tqdm = "^4.64.0"
|
||||||
wandb = "^0.12.19"
|
wandb = "^0.12.19"
|
||||||
rich = "^12.4.4"
|
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = "^22.3.0"
|
black = "^22.3.0"
|
||||||
|
|
|
@ -3,12 +3,12 @@ import logging
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import onnxruntime
|
||||||
import torch
|
import torch
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from unet import UNet
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -38,23 +38,23 @@ def get_args():
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid(x):
|
||||||
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
net = UNet(n_channels=3, n_classes=1)
|
onnx_model = onnx.load(args.model)
|
||||||
|
onnx.checker.check_model(onnx_model)
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
ort_session = onnxruntime.InferenceSession(args.model)
|
||||||
logging.info(f"Using device {device}")
|
|
||||||
|
|
||||||
logging.info("Transfering model to device")
|
def to_numpy(tensor):
|
||||||
net.to(device=device)
|
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
||||||
|
|
||||||
logging.info(f"Loading model {args.model}")
|
|
||||||
net.load_state_dict(torch.load(args.model, map_location=device))
|
|
||||||
|
|
||||||
logging.info(f"Loading image {args.input}")
|
|
||||||
img = Image.open(args.input).convert("RGB")
|
img = Image.open(args.input).convert("RGB")
|
||||||
|
|
||||||
logging.info(f"Preprocessing image {args.input}")
|
logging.info(f"Preprocessing image {args.input}")
|
||||||
|
@ -68,17 +68,14 @@ if __name__ == "__main__":
|
||||||
img = aug["image"]
|
img = aug["image"]
|
||||||
|
|
||||||
logging.info(f"Predicting image {args.input}")
|
logging.info(f"Predicting image {args.input}")
|
||||||
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
|
img = img.unsqueeze(0)
|
||||||
|
|
||||||
net.eval()
|
# compute ONNX Runtime output prediction
|
||||||
with torch.inference_mode():
|
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
|
||||||
mask = net(img)
|
ort_outs = ort_session.run(None, ort_inputs)
|
||||||
mask = torch.sigmoid(mask)[0]
|
|
||||||
mask = mask.cpu()
|
|
||||||
mask = mask.squeeze()
|
|
||||||
mask = mask > 0.5
|
|
||||||
mask = np.asarray(mask)
|
|
||||||
|
|
||||||
logging.info(f"Saving prediction to {args.output}")
|
img_out_y = ort_outs[0]
|
||||||
mask = Image.fromarray(mask)
|
|
||||||
mask.save(args.output)
|
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
|
||||||
|
|
||||||
|
img_out_y.save(args.output)
|
||||||
|
|
|
@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
|
||||||
self,
|
self,
|
||||||
nb,
|
nb,
|
||||||
image_dir,
|
image_dir,
|
||||||
scale_range=(0.1, 0.2),
|
scale_range=(0.05, 0.25),
|
||||||
always_apply=True,
|
always_apply=True,
|
||||||
p=1.0,
|
p=1.0,
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in a new issue