mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Merge branch 'onnx' into pl
Former-commit-id: 7f28adc38d86319181dea3e311976162ae20b6d2 [formerly 2b8801777106b1a17a9cef2c3a7d65e6ad780f3f] Former-commit-id: 6d658bfd379005fcfdb75f97509be8dc920a8ef5
This commit is contained in:
commit
70b19b3b94
8
.vscode/launch.json
vendored
8
.vscode/launch.json
vendored
|
@ -12,11 +12,13 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--input",
|
||||
"images/SM.png",
|
||||
"images/test.png",
|
||||
"--output",
|
||||
"output.png",
|
||||
"output_onnx.png",
|
||||
"--model",
|
||||
"good.onnx",
|
||||
],
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
3409
poetry.lock
generated
3409
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -8,14 +8,16 @@ version = "0.1.0"
|
|||
albumentations = "^1.2.0"
|
||||
matplotlib = "^3.5.2"
|
||||
numpy = "^1.23.0"
|
||||
onnx = "^1.12.0"
|
||||
onnxruntime = "^1.11.1"
|
||||
python = ">=3.8,<3.11"
|
||||
pytorch-lightning = "^1.6.4"
|
||||
rich = "^12.4.4"
|
||||
scipy = "^1.8.1"
|
||||
torch = "^1.11.0"
|
||||
torchvision = "^0.12.0"
|
||||
tqdm = "^4.64.0"
|
||||
wandb = "^0.12.19"
|
||||
rich = "^12.4.4"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^22.3.0"
|
||||
|
|
|
@ -3,12 +3,12 @@ import logging
|
|||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime
|
||||
import torch
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from PIL import Image
|
||||
|
||||
from unet import UNet
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -38,23 +38,23 @@ def get_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def sigmoid(x):
|
||||
return 1 / (1 + np.exp(-x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
net = UNet(n_channels=3, n_classes=1)
|
||||
onnx_model = onnx.load(args.model)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logging.info(f"Using device {device}")
|
||||
ort_session = onnxruntime.InferenceSession(args.model)
|
||||
|
||||
logging.info("Transfering model to device")
|
||||
net.to(device=device)
|
||||
def to_numpy(tensor):
|
||||
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
||||
|
||||
logging.info(f"Loading model {args.model}")
|
||||
net.load_state_dict(torch.load(args.model, map_location=device))
|
||||
|
||||
logging.info(f"Loading image {args.input}")
|
||||
img = Image.open(args.input).convert("RGB")
|
||||
|
||||
logging.info(f"Preprocessing image {args.input}")
|
||||
|
@ -68,17 +68,14 @@ if __name__ == "__main__":
|
|||
img = aug["image"]
|
||||
|
||||
logging.info(f"Predicting image {args.input}")
|
||||
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
net.eval()
|
||||
with torch.inference_mode():
|
||||
mask = net(img)
|
||||
mask = torch.sigmoid(mask)[0]
|
||||
mask = mask.cpu()
|
||||
mask = mask.squeeze()
|
||||
mask = mask > 0.5
|
||||
mask = np.asarray(mask)
|
||||
# compute ONNX Runtime output prediction
|
||||
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
|
||||
ort_outs = ort_session.run(None, ort_inputs)
|
||||
|
||||
logging.info(f"Saving prediction to {args.output}")
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save(args.output)
|
||||
img_out_y = ort_outs[0]
|
||||
|
||||
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
|
||||
|
||||
img_out_y.save(args.output)
|
||||
|
|
|
@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
|
|||
self,
|
||||
nb,
|
||||
image_dir,
|
||||
scale_range=(0.1, 0.2),
|
||||
scale_range=(0.05, 0.25),
|
||||
always_apply=True,
|
||||
p=1.0,
|
||||
):
|
||||
|
|
Loading…
Reference in a new issue