feat: dynamic onnx ?

Former-commit-id: 619c74a13d0674fc77bd5c1bf711013c1b3d4626 [formerly 762126125c2f108855a0837f3688f28e1002dcf7]
Former-commit-id: 7aa443fd8b68603171d2bbfa87bb9eddbe6dc066
This commit is contained in:
Laurent Fainsin 2022-07-08 11:05:44 +02:00
parent 90978bfdc3
commit 7ca448803b
5 changed files with 114 additions and 3 deletions

105
src/dynamic.ipynb Normal file
View file

@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from unet import UNet\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"net = UNet(\n",
" n_channels=3,\n",
" n_classes=1,\n",
" batch_size=1,\n",
" learning_rate=1e-4,\n",
" features=[8, 16, 32, 64],\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.load_state_dict(\n",
" torch.load(\"../best.pth\")\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
"torch.onnx.export(\n",
" net,\n",
" dummy_input,\n",
" \"model-test.onnx\",\n",
" opset_version=14,\n",
" input_names=[\"input\"],\n",
" output_names=[\"output\"],\n",
" dynamic_axes={\n",
" \"input\": {\n",
" 2: \"height\",\n",
" 3: \"width\",\n",
" },\n",
" \"output\": {\n",
" 2: \"height\",\n",
" 3: \"width\",\n",
" },\n",
" },\n",
")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.0 ('.venv': poetry)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -59,8 +59,14 @@ class Up(nn.Module):
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
x1 = F.pad(
input=x1,
pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
)
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)

View file

@ -268,14 +268,14 @@ class UNet(pl.LightningModule):
# export model to pth
torch.save(self.state_dict(), f"checkpoints/model.pth")
artifact = wandb.Artifact("pth", type="model")
artifact.add_file(f"checkpoints/model.pth")
artifact.add_file("checkpoints/model.pth")
wandb.run.log_artifact(artifact)
# export model to onnx
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file(f"checkpoints/model.onnx")
artifact.add_file("checkpoints/model.onnx")
wandb.run.log_artifact(artifact)
def configure_optimizers(self):