feat: confirming synth "overfitting"

Former-commit-id: 9a6d691503c72e76ac68eab378fc07f0e35f5182 [formerly 9c0469156cf11b9d7540c52115f9fb20ce873d5d]
Former-commit-id: c597ef791d930739a3c56ba62e3b0070bff0e82b
Laurent Fainsin, 2022-07-12 14:21:23 +02:00
parent 50d18a5b39
commit dc833d2a88
3 changed files with 246 additions and 115 deletions


@@ -2,32 +2,30 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
 "import torch\n",
-"from unet import UNet\n"
+"from unet.model import UNet\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 7,
 "metadata": {},
 "outputs": [],
 "source": [
 "net = UNet(\n",
 " n_channels=3,\n",
 " n_classes=1,\n",
-" batch_size=1,\n",
-" learning_rate=1e-4,\n",
 " features=[8, 16, 32, 64],\n",
 ")\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 10,
 "metadata": {},
 "outputs": [
 {
@@ -36,28 +34,41 @@
 "<All keys matched successfully>"
 ]
 },
-"execution_count": 3,
+"execution_count": 10,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
 "net.load_state_dict(\n",
-" torch.load(\"../checkpoint/best.pth\")\n",
+" torch.load(\"../../checkpoints/best.pth\")\n",
 ")\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 11,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
+"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
+"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
+"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
+"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
+"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n"
+]
+}
+],
 "source": [
 "dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
 "torch.onnx.export(\n",
 " net,\n",
 " dummy_input,\n",
-" \"../checkpoint/best.onnx\",\n",
+" \"../../checkpoints/best.onnx\",\n",
 " opset_version=14,\n",
 " input_names=[\"input\"],\n",
 " output_names=[\"output\"],\n",

File diff suppressed because one or more lines are too long
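The notebook diff above loads the best checkpoint and exports it to ONNX, but the commit shows no check of the exported graph. Below is a minimal sketch of such a check; it reuses the paths and constructor arguments visible in the diff, while the onnxruntime dependency and everything else are illustrative assumptions, not part of the commit:

import numpy as np
import onnxruntime as ort
import torch

from unet.model import UNet

# rebuild the network as in the notebook and load the trained weights
net = UNet(n_channels=3, n_classes=1, features=[8, 16, 32, 64])
net.load_state_dict(torch.load("../../checkpoints/best.pth"))
net.eval()

# run the same dummy input through PyTorch and through the exported ONNX graph
dummy = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    torch_out = net(dummy).numpy()

session = ort.InferenceSession("../../checkpoints/best.onnx", providers=["CPUExecutionProvider"])
(onnx_out,) = session.run(None, {"input": dummy.numpy()})

# the two outputs should agree up to small numerical noise
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)

The prim::Constant shape-inference warnings recorded in the notebook output come from the exporter and are typically benign, which is one more reason to compare the two outputs explicitly rather than rely on the absence of errors.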


@@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
         self,
         nb,
         image_dir,
-        scale_range=(0.05, 0.25),
+        scale_range=(0.05, 0.5),
         always_apply=True,
         p=1.0,
     ):
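For intuition, the widened scale_range doubles the largest possible paste relative to the image. A back-of-the-envelope sketch using the same shape formula as the loop in the next hunk (paste_shape and min_scale are placeholders here; in RandomPaste they are derived from the images involved):

import numpy as np

paste_shape = np.array([1024, 1024])
min_scale = 1.0  # placeholder value, illustrative only

for lo, hi in [(0.05, 0.25), (0.05, 0.5)]:
    smallest = np.array(paste_shape * lo * min_scale, dtype=np.uint)
    largest = np.array(paste_shape * hi * min_scale, dtype=np.uint)
    print(f"scale_range=({lo}, {hi}): paste shape from {smallest} to {largest}")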
@@ -138,7 +138,12 @@ class RandomPaste(A.DualTransform):
         # generate augmentations
         augmentations = []
         NB = rd.randint(1, self.nb)
-        while len(augmentations) < NB:  # TODO: add a stop condition (max iterations)
+        ite = 0
+        while len(augmentations) < NB:
+            if ite > 100:
+                break
             scale = rd.uniform(*self.scale_range) * min_scale
             shape = np.array(paste_shape * scale, dtype=np.uint)
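The rest of this hunk is cut off above, so the counter increment is not visible in the diff. As a standalone illustration only (hypothetical function name, simplified, and without whatever acceptance logic fills augmentations in the real transform), the guarded sampling loop amounts to something like this:

import random as rd

def sample_paste_scales(nb, scale_range=(0.05, 0.5), min_scale=1.0, max_ite=100):
    """Hypothetical helper mirroring the bounded loop introduced in the diff."""
    augmentations = []
    NB = rd.randint(1, nb)
    ite = 0
    while len(augmentations) < NB:
        if ite > max_ite:
            # give up instead of looping forever when candidates keep failing
            break
        ite += 1  # the counter must advance on every attempt for the guard to trigger
        scale = rd.uniform(*scale_range) * min_scale
        # the real transform only keeps a candidate after further checks;
        # here every candidate is accepted
        augmentations.append(scale)
    return augmentations

print(sample_paste_scales(nb=5))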