feat: confiming synth "overfitting"

Former-commit-id: 9a6d691503c72e76ac68eab378fc07f0e35f5182 [formerly 9c0469156cf11b9d7540c52115f9fb20ce873d5d]
Former-commit-id: c597ef791d930739a3c56ba62e3b0070bff0e82b
This commit is contained in:
Laurent Fainsin 2022-07-12 14:21:23 +02:00
parent 50d18a5b39
commit dc833d2a88
3 changed files with 246 additions and 115 deletions

View file

@ -2,32 +2,30 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from unet import UNet\n"
"from unet.model import UNet\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"net = UNet(\n",
" n_channels=3,\n",
" n_classes=1,\n",
" batch_size=1,\n",
" learning_rate=1e-4,\n",
" features=[8, 16, 32, 64],\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@ -36,28 +34,41 @@
"<All keys matched successfully>"
]
},
"execution_count": 3,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.load_state_dict(\n",
" torch.load(\"../checkpoint/best.pth\")\n",
" torch.load(\"../../checkpoints/best.pth\")\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n"
]
}
],
"source": [
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
"torch.onnx.export(\n",
" net,\n",
" dummy_input,\n",
" \"../checkpoint/best.onnx\",\n",
" \"../../checkpoints/best.onnx\",\n",
" opset_version=14,\n",
" input_names=[\"input\"],\n",
" output_names=[\"output\"],\n",

File diff suppressed because one or more lines are too long

View file

@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
self,
nb,
image_dir,
scale_range=(0.05, 0.25),
scale_range=(0.05, 0.5),
always_apply=True,
p=1.0,
):
@ -138,7 +138,12 @@ class RandomPaste(A.DualTransform):
# generate augmentations
augmentations = []
NB = rd.randint(1, self.nb)
while len(augmentations) < NB: # TODO: mettre une condition d'arret ite max
ite = 0
while len(augmentations) < NB:
if ite > 100:
break
scale = rd.uniform(*self.scale_range) * min_scale
shape = np.array(paste_shape * scale, dtype=np.uint)