feat: dynamic onnx ?
Former-commit-id: 619c74a13d0674fc77bd5c1bf711013c1b3d4626 [formerly 762126125c2f108855a0837f3688f28e1002dcf7] Former-commit-id: 7aa443fd8b68603171d2bbfa87bb9eddbe6dc066
This commit is contained in:
parent
90978bfdc3
commit
7ca448803b
105
src/dynamic.ipynb
Normal file
105
src/dynamic.ipynb
Normal file
|
@ -0,0 +1,105 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from unet import UNet\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"net = UNet(\n",
|
||||
" n_channels=3,\n",
|
||||
" n_classes=1,\n",
|
||||
" batch_size=1,\n",
|
||||
" learning_rate=1e-4,\n",
|
||||
" features=[8, 16, 32, 64],\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"net.load_state_dict(\n",
|
||||
" torch.load(\"../best.pth\")\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
|
||||
"torch.onnx.export(\n",
|
||||
" net,\n",
|
||||
" dummy_input,\n",
|
||||
" \"model-test.onnx\",\n",
|
||||
" opset_version=14,\n",
|
||||
" input_names=[\"input\"],\n",
|
||||
" output_names=[\"output\"],\n",
|
||||
" dynamic_axes={\n",
|
||||
" \"input\": {\n",
|
||||
" 2: \"height\",\n",
|
||||
" 3: \"width\",\n",
|
||||
" },\n",
|
||||
" \"output\": {\n",
|
||||
" 2: \"height\",\n",
|
||||
" 3: \"width\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.0 ('.venv': poetry)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -59,8 +59,14 @@ class Up(nn.Module):
|
|||
# input is CHW
|
||||
diffY = x2.size()[2] - x1.size()[2]
|
||||
diffX = x2.size()[3] - x1.size()[3]
|
||||
diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
|
||||
diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
|
||||
|
||||
x1 = F.pad(
|
||||
input=x1,
|
||||
pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
|
||||
)
|
||||
|
||||
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
|
||||
x = torch.cat([x2, x1], dim=1)
|
||||
|
||||
return self.conv(x)
|
||||
|
|
|
@ -268,14 +268,14 @@ class UNet(pl.LightningModule):
|
|||
# export model to pth
|
||||
torch.save(self.state_dict(), f"checkpoints/model.pth")
|
||||
artifact = wandb.Artifact("pth", type="model")
|
||||
artifact.add_file(f"checkpoints/model.pth")
|
||||
artifact.add_file("checkpoints/model.pth")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# export model to onnx
|
||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
|
||||
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
|
||||
artifact = wandb.Artifact("onnx", type="model")
|
||||
artifact.add_file(f"checkpoints/model.onnx")
|
||||
artifact.add_file("checkpoints/model.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
def configure_optimizers(self):
|
||||
|
|
Loading…
Reference in a new issue