feat: dynamic onnx ?
Former-commit-id: 619c74a13d0674fc77bd5c1bf711013c1b3d4626 [formerly 762126125c2f108855a0837f3688f28e1002dcf7] Former-commit-id: 7aa443fd8b68603171d2bbfa87bb9eddbe6dc066
This commit is contained in:
parent
90978bfdc3
commit
7ca448803b
105
src/dynamic.ipynb
Normal file
105
src/dynamic.ipynb
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"from unet import UNet\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"net = UNet(\n",
|
||||||
|
" n_channels=3,\n",
|
||||||
|
" n_classes=1,\n",
|
||||||
|
" batch_size=1,\n",
|
||||||
|
" learning_rate=1e-4,\n",
|
||||||
|
" features=[8, 16, 32, 64],\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<All keys matched successfully>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"net.load_state_dict(\n",
|
||||||
|
" torch.load(\"../best.pth\")\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
|
||||||
|
"torch.onnx.export(\n",
|
||||||
|
" net,\n",
|
||||||
|
" dummy_input,\n",
|
||||||
|
" \"model-test.onnx\",\n",
|
||||||
|
" opset_version=14,\n",
|
||||||
|
" input_names=[\"input\"],\n",
|
||||||
|
" output_names=[\"output\"],\n",
|
||||||
|
" dynamic_axes={\n",
|
||||||
|
" \"input\": {\n",
|
||||||
|
" 2: \"height\",\n",
|
||||||
|
" 3: \"width\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"output\": {\n",
|
||||||
|
" 2: \"height\",\n",
|
||||||
|
" 3: \"width\",\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3.8.0 ('.venv': poetry)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.0"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4,
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -59,8 +59,14 @@ class Up(nn.Module):
|
||||||
# input is CHW
|
# input is CHW
|
||||||
diffY = x2.size()[2] - x1.size()[2]
|
diffY = x2.size()[2] - x1.size()[2]
|
||||||
diffX = x2.size()[3] - x1.size()[3]
|
diffX = x2.size()[3] - x1.size()[3]
|
||||||
|
diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
|
||||||
|
diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
|
||||||
|
|
||||||
|
x1 = F.pad(
|
||||||
|
input=x1,
|
||||||
|
pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
|
||||||
|
)
|
||||||
|
|
||||||
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
|
|
||||||
x = torch.cat([x2, x1], dim=1)
|
x = torch.cat([x2, x1], dim=1)
|
||||||
|
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
|
@ -268,14 +268,14 @@ class UNet(pl.LightningModule):
|
||||||
# export model to pth
|
# export model to pth
|
||||||
torch.save(self.state_dict(), f"checkpoints/model.pth")
|
torch.save(self.state_dict(), f"checkpoints/model.pth")
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
artifact.add_file(f"checkpoints/model.pth")
|
artifact.add_file("checkpoints/model.pth")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
# export model to onnx
|
# export model to onnx
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
|
||||||
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
|
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
artifact.add_file(f"checkpoints/model.onnx")
|
artifact.add_file("checkpoints/model.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
|
|
Loading…
Reference in a new issue