mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 23:12:05 +00:00
feat: confiming synth "overfitting"
Former-commit-id: 9a6d691503c72e76ac68eab378fc07f0e35f5182 [formerly 9c0469156cf11b9d7540c52115f9fb20ce873d5d] Former-commit-id: c597ef791d930739a3c56ba62e3b0070bff0e82b
This commit is contained in:
parent
50d18a5b39
commit
dc833d2a88
|
@ -2,32 +2,30 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"from unet import UNet\n"
|
"from unet.model import UNet\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"net = UNet(\n",
|
"net = UNet(\n",
|
||||||
" n_channels=3,\n",
|
" n_channels=3,\n",
|
||||||
" n_classes=1,\n",
|
" n_classes=1,\n",
|
||||||
" batch_size=1,\n",
|
|
||||||
" learning_rate=1e-4,\n",
|
|
||||||
" features=[8, 16, 32, 64],\n",
|
" features=[8, 16, 32, 64],\n",
|
||||||
")\n"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -36,28 +34,41 @@
|
||||||
"<All keys matched successfully>"
|
"<All keys matched successfully>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"net.load_state_dict(\n",
|
"net.load_state_dict(\n",
|
||||||
" torch.load(\"../checkpoint/best.pth\")\n",
|
" torch.load(\"../../checkpoints/best.pth\")\n",
|
||||||
")\n"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 11,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||||
|
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||||
|
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||||
|
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||||
|
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||||
|
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
|
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
|
||||||
"torch.onnx.export(\n",
|
"torch.onnx.export(\n",
|
||||||
" net,\n",
|
" net,\n",
|
||||||
" dummy_input,\n",
|
" dummy_input,\n",
|
||||||
" \"../checkpoint/best.onnx\",\n",
|
" \"../../checkpoints/best.onnx\",\n",
|
||||||
" opset_version=14,\n",
|
" opset_version=14,\n",
|
||||||
" input_names=[\"input\"],\n",
|
" input_names=[\"input\"],\n",
|
||||||
" output_names=[\"output\"],\n",
|
" output_names=[\"output\"],\n",
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
|
||||||
self,
|
self,
|
||||||
nb,
|
nb,
|
||||||
image_dir,
|
image_dir,
|
||||||
scale_range=(0.05, 0.25),
|
scale_range=(0.05, 0.5),
|
||||||
always_apply=True,
|
always_apply=True,
|
||||||
p=1.0,
|
p=1.0,
|
||||||
):
|
):
|
||||||
|
@ -138,7 +138,12 @@ class RandomPaste(A.DualTransform):
|
||||||
# generate augmentations
|
# generate augmentations
|
||||||
augmentations = []
|
augmentations = []
|
||||||
NB = rd.randint(1, self.nb)
|
NB = rd.randint(1, self.nb)
|
||||||
while len(augmentations) < NB: # TODO: mettre une condition d'arret ite max
|
ite = 0
|
||||||
|
while len(augmentations) < NB:
|
||||||
|
|
||||||
|
if ite > 100:
|
||||||
|
break
|
||||||
|
|
||||||
scale = rd.uniform(*self.scale_range) * min_scale
|
scale = rd.uniform(*self.scale_range) * min_scale
|
||||||
shape = np.array(paste_shape * scale, dtype=np.uint)
|
shape = np.array(paste_shape * scale, dtype=np.uint)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue