From 7ca448803b9487a165129228b33124e5a765373c Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Fri, 8 Jul 2022 11:05:44 +0200 Subject: [PATCH] feat: dynamic onnx ? Former-commit-id: 619c74a13d0674fc77bd5c1bf711013c1b3d4626 [formerly 762126125c2f108855a0837f3688f28e1002dcf7] Former-commit-id: 7aa443fd8b68603171d2bbfa87bb9eddbe6dc066 --- .../comp.ipynb.REMOVED.git-id | 0 src/dynamic.ipynb | 105 ++++++++++++++++++ extract.ipynb => src/extract.ipynb | 0 src/unet/blocks.py | 8 +- src/unet/model.py | 4 +- 5 files changed, 114 insertions(+), 3 deletions(-) rename comp.ipynb.REMOVED.git-id => src/comp.ipynb.REMOVED.git-id (100%) create mode 100644 src/dynamic.ipynb rename extract.ipynb => src/extract.ipynb (100%) diff --git a/comp.ipynb.REMOVED.git-id b/src/comp.ipynb.REMOVED.git-id similarity index 100% rename from comp.ipynb.REMOVED.git-id rename to src/comp.ipynb.REMOVED.git-id diff --git a/src/dynamic.ipynb b/src/dynamic.ipynb new file mode 100644 index 0000000..371cb92 --- /dev/null +++ b/src/dynamic.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from unet import UNet\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "net = UNet(\n", + " n_channels=3,\n", + " n_classes=1,\n", + " batch_size=1,\n", + " learning_rate=1e-4,\n", + " features=[8, 16, 32, 64],\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.load_state_dict(\n", + " torch.load(\"../best.pth\")\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n", + "torch.onnx.export(\n", + " net,\n", + " dummy_input,\n", + " \"model-test.onnx\",\n", + " opset_version=14,\n", + " input_names=[\"input\"],\n", + " output_names=[\"output\"],\n", + " dynamic_axes={\n", + " \"input\": {\n", + " 2: \"height\",\n", + " 3: \"width\",\n", + " },\n", + " \"output\": {\n", + " 2: \"height\",\n", + " 3: \"width\",\n", + " },\n", + " },\n", + ")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.0 ('.venv': poetry)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/extract.ipynb b/src/extract.ipynb similarity index 100% rename from extract.ipynb rename to src/extract.ipynb diff --git a/src/unet/blocks.py b/src/unet/blocks.py index d125002..0df7f5f 100644 --- a/src/unet/blocks.py +++ b/src/unet/blocks.py @@ -59,8 +59,14 @@ class Up(nn.Module): # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] + diffY2 = torch.div(diffY, 2, rounding_mode="trunc") + diffX2 = torch.div(diffX, 2, rounding_mode="trunc") + + x1 = F.pad( + input=x1, + pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2], + ) - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) x = torch.cat([x2, x1], dim=1) return self.conv(x) diff --git a/src/unet/model.py b/src/unet/model.py index 11ddc65..c244735 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -268,14 +268,14 @@ class UNet(pl.LightningModule): # export model to pth torch.save(self.state_dict(), f"checkpoints/model.pth") artifact = wandb.Artifact("pth", type="model") - artifact.add_file(f"checkpoints/model.pth") + artifact.add_file("checkpoints/model.pth") wandb.run.log_artifact(artifact) # export model to onnx dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") artifact = wandb.Artifact("onnx", type="model") - artifact.add_file(f"checkpoints/model.onnx") + artifact.add_file("checkpoints/model.onnx") wandb.run.log_artifact(artifact) def configure_optimizers(self):