feat: confiming synth "overfitting"
Former-commit-id: 9a6d691503c72e76ac68eab378fc07f0e35f5182 [formerly 9c0469156cf11b9d7540c52115f9fb20ce873d5d] Former-commit-id: c597ef791d930739a3c56ba62e3b0070bff0e82b
This commit is contained in:
parent
50d18a5b39
commit
dc833d2a88
|
@ -2,32 +2,30 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from unet import UNet\n"
|
||||
"from unet.model import UNet\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"net = UNet(\n",
|
||||
" n_channels=3,\n",
|
||||
" n_classes=1,\n",
|
||||
" batch_size=1,\n",
|
||||
" learning_rate=1e-4,\n",
|
||||
" features=[8, 16, 32, 64],\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -36,28 +34,41 @@
|
|||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"net.load_state_dict(\n",
|
||||
" torch.load(\"../checkpoint/best.pth\")\n",
|
||||
" torch.load(\"../../checkpoints/best.pth\")\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
|
||||
"torch.onnx.export(\n",
|
||||
" net,\n",
|
||||
" dummy_input,\n",
|
||||
" \"../checkpoint/best.onnx\",\n",
|
||||
" \"../../checkpoints/best.onnx\",\n",
|
||||
" opset_version=14,\n",
|
||||
" input_names=[\"input\"],\n",
|
||||
" output_names=[\"output\"],\n",
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
|
|||
self,
|
||||
nb,
|
||||
image_dir,
|
||||
scale_range=(0.05, 0.25),
|
||||
scale_range=(0.05, 0.5),
|
||||
always_apply=True,
|
||||
p=1.0,
|
||||
):
|
||||
|
@ -138,7 +138,12 @@ class RandomPaste(A.DualTransform):
|
|||
# generate augmentations
|
||||
augmentations = []
|
||||
NB = rd.randint(1, self.nb)
|
||||
while len(augmentations) < NB: # TODO: mettre une condition d'arret ite max
|
||||
ite = 0
|
||||
while len(augmentations) < NB:
|
||||
|
||||
if ite > 100:
|
||||
break
|
||||
|
||||
scale = rd.uniform(*self.scale_range) * min_scale
|
||||
shape = np.array(paste_shape * scale, dtype=np.uint)
|
||||
|
||||
|
|
Loading…
Reference in a new issue