diff --git a/src/notebooks/convert.ipynb b/src/notebooks/convert.ipynb index 94363d6..37f1c63 100644 --- a/src/notebooks/convert.ipynb +++ b/src/notebooks/convert.ipynb @@ -2,32 +2,30 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import torch\n", - "from unet import UNet\n" + "from unet.model import UNet\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "net = UNet(\n", " n_channels=3,\n", " n_classes=1,\n", - " batch_size=1,\n", - " learning_rate=1e-4,\n", " features=[8, 16, 32, 64],\n", ")\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -36,28 +34,41 @@ "" ] }, - "execution_count": 3, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.load_state_dict(\n", - " torch.load(\"../checkpoint/best.pth\")\n", + " torch.load(\"../../checkpoints/best.pth\")\n", ")\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n", + "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n", + "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n", + "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n", + "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n", + "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n" + ] + } + ], "source": [ "dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n", "torch.onnx.export(\n", " net,\n", " dummy_input,\n", - " \"../checkpoint/best.onnx\",\n", + " \"../../checkpoints/best.onnx\",\n", " opset_version=14,\n", " input_names=[\"input\"],\n", " output_names=[\"output\"],\n", diff --git a/src/notebooks/predict.ipynb b/src/notebooks/predict.ipynb index f4ddd5b..e83106f 100644 --- a/src/notebooks/predict.ipynb +++ b/src/notebooks/predict.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -15,12 +15,15 @@ "\n", "%config InlineBackend.figure_formats = ['svg']\n", "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n" + "%matplotlib inline\n", + "\n", + "import torch\n", + "from utils.dice import dice_score\n" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -40,18 +43,19 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_path = \"../../checkpoints/best.onnx\"\n", "# image_path = \"../../images/SM.png\"\n", - "image_path = \"/home/lilian/data_disk/lfainsin/test/2022_SM/DOS_DETAIL/DSC_0050.jpg\"\n" + "image_path = \"/home/lilian/data_disk/lfainsin/test/2022_SM/DOS_DETAIL/DSC_0055.jpg\"\n", + "gt_path = \"/home/lilian/data_disk/lfainsin/test/2022_SM/DOS_DETAIL/MASK.PNG\"\n" ] }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -61,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -70,29 +74,9 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(8256, 5504)\n" - ] - }, - { - "data": { - "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-12T11:04:17.764332\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "image = Image.open(image_path).convert(\"RGB\")\n", "\n", @@ -104,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -123,32 +107,9 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-12T11:04:23.272460\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "inputs = {\n", " session.get_inputs()[0].name: to_numpy(img),\n", @@ -162,30 +123,216 @@ "img_out = np.uint8(img_out) # [0, 255]\n", "img_out = Image.fromarray(img_out, \"L\") # PIL img\n", "\n", - "plt.figure(figsize=(25, 10))\n", - "plt.imshow(img_out)\n" + "# plt.figure(figsize=(25, 10))\n", + "# plt.imshow(img_out, cmap=plt.cm.gray)\n" ] }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = {\n", + " session.get_inputs()[0].name: to_numpy(img),\n", + "}\n", + "\n", + "outs = session.run(None, inputs)\n", + "\n", + "img_out = outs[0][0][0] # extract HW\n", + "img_out = sigmoid(img_out) # -> [0.0, 1.0]\n", + "img_out = img_out * 255 # [0.0, 255.0]\n", + "img_out = np.uint8(img_out) # [0, 255]\n", + "img_out = Image.fromarray(img_out, \"L\") # PIL img\n", + "\n", + "plt.figure(figsize=(25, 10))\n", + "plt.imshow(img_out, cmap=plt.cm.gray)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from skimage import measure, draw\n", + "\n", + "# Find contours at a constant value of 0.8\n", + "contours = measure.find_contours(np.asarray(img_out))\n", + "\n", + "# Display the image and plot all contours found\n", + "plt.figure(figsize=(19, 18))\n", + "plt.imshow(img_out, cmap=plt.cm.gray)\n", + "\n", + "lenc = [len(c) for c in contours]\n", + "indexs = np.argsort(lenc)\n", + "l = indexs[-1]\n", + "\n", + "# plt.plot(contours[l][:, 1], contours[l][:, 0], linewidth=2, c=\"red\")\n", + "\n", + "# on estime l'ellipse\n", + "ellipse = measure.EllipseModel()\n", + "ellipse.estimate(contours[l])\n", + "\n", + "# on récupère les coords des points de l'ellipse\n", + "cx, cy, a, b, theta = ellipse.params\n", + "ex, ey = draw.ellipse_perimeter(int(cx), int(cy), int(a), int(b), orientation=theta, shape=img_out.size[::-1])\n", + "\n", + "plt.scatter(ey, ex, c=\"green\", s=0.5)\n", + "plt.scatter(cy, cx, c=\"green\", s=0.5)\n", + "\n", + "# # on estime le cercle\n", + "circle = measure.CircleModel()\n", + "circle.estimate(contours[l])\n", + "\n", + "# on récupère les coords des points du cercle\n", + "cx, cy, r = circle.params\n", + "ex, ey = draw.circle_perimeter(int(cx), int(cy), int(r), shape=img_out.size[::-1])\n", + "\n", + "plt.scatter(ey, ex, c=\"blue\", s=0.5)\n", + "plt.scatter(cy, cx, c=\"blue\", s=0.5)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# mutilresolution zoom" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "factor = image.size[0] / img_out.size[0]\n", + "taille = min(a, b) * factor\n", + "percentage = 0.1\n", + "size = taille / percentage\n", + "\n", + "img = image.crop((cy*factor - size, cx*factor - size/1.5, cy*factor + size, cx*factor + size/1.5))\n", + "\n", + "transform = A.Compose(\n", + " [\n", + " A.LongestMaxSize(1024),\n", + " A.ToFloat(max_value=255), # [0, 255] -> [0.0, 1.0]\n", + " ToTensorV2(), # HWC -> CHW\n", + " ],\n", + ")\n", + "aug = transform(image=np.asarray(img))\n", + "img = aug[\"image\"]\n", + "\n", + "img = img.unsqueeze(0) # -> 1CHW" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = {\n", + " session.get_inputs()[0].name: to_numpy(img),\n", + "}\n", + "\n", + "outs = session.run(None, inputs)\n", + "\n", + "img_out = outs[0][0][0] # extract HW\n", + "img_out = sigmoid(img_out) # -> [0.0, 1.0]\n", + "img_out = img_out * 255 # [0.0, 255.0]\n", + "img_out = np.uint8(img_out) # [0, 255]\n", + "img_out = Image.fromarray(img_out, \"L\") # PIL img\n", + "\n", + "plt.figure(figsize=(25, 10))\n", + "plt.imshow(img_out, cmap=plt.cm.gray)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.356565656565656564\r" + ] + } + ], + "source": [ + "image = Image.open(image_path).convert(\"RGB\")\n", + "ground_truth = Image.open(gt_path).convert(\"L\")\n", + "\n", + "Image.MAX_IMAGE_PIXELS = 5000000000 \n", + "\n", + "dices = []\n", + "percentages = np.linspace(0.01, 0.35, 100)\n", + "\n", + "for p in percentages:\n", + " size = taille / p \n", + " print(p, end=\"\\r\")\n", + "\n", + " img = image.crop((cy * factor - size, cx * factor - size, cy * factor + size, cx * factor + size))\n", + " gt = ground_truth.crop((cy * factor - size, cx * factor - size, cy * factor + size, cx * factor + size))\n", + "\n", + " transform = A.Compose(\n", + " [\n", + " A.LongestMaxSize(1024),\n", + " A.ToFloat(max_value=255), # [0, 255] -> [0.0, 1.0]\n", + " ToTensorV2(), # HWC -> CHW\n", + " ],\n", + " )\n", + " aug = transform(image=np.asarray(img))\n", + " img = aug[\"image\"]\n", + " img = img.unsqueeze(0) # -> 1CHW\n", + "\n", + " inputs = {\n", + " session.get_inputs()[0].name: to_numpy(img),\n", + " }\n", + "\n", + " outs = session.run(None, inputs)\n", + "\n", + " img_out2 = outs[0][0][0] # extract HW\n", + " img_out2 = sigmoid(img_out2) # -> [0.0, 1.0]\n", + " img_out2 = img_out2 > 0.5 # {False, True}\n", + " img_out2 = np.uint8(img_out2) # [0, 1]\n", + " img_out2 = torch.tensor(img_out2)\n", + "\n", + " aug = transform(image=np.array(gt))\n", + " gt = aug[\"image\"]\n", + " gt = gt > 0.1 # {False, True}\n", + " gt = np.uint8(gt) # [0, 1]\n", + " gt = torch.tensor(gt)\n", + "\n", + " dice = dice_score(gt, img_out2, logits=False)\n", + " dices.append(dice)\n", + "\n", + " # plt.figure(figsize=(25, 10))\n", + " # plt.imshow(Image.fromarray(gt.squeeze(0).numpy()), cmap=plt.cm.gray)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "[]" ] }, - "execution_count": 61, + "execution_count": 60, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-12T11:04:25.814018\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-12T14:21:04.233003\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "text/plain": [ - "
" + "
" ] }, "metadata": { @@ -195,24 +342,15 @@ } ], "source": [ - "img_out = outs[0][0][0] # extract HW\n", - "img_out = sigmoid(img_out) # -> [0.0, 1.0]\n", - "img_out = img_out * 255 # [0.0, 255.0]\n", - "img_out = np.uint8(img_out) # [0, 255]\n", - "img_out = Image.fromarray(img_out, \"L\") # PIL img\n", - "\n", - "plt.figure(figsize=(25, 10))\n", - "plt.imshow(img_out)\n" + "plt.figure(figsize=(15, 10))\n", + "plt.plot(percentages, dices)\n" ] }, { - "cell_type": "code", - "execution_count": 62, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "import torch\n", - "from utils.dice import dice_score" + "# Dice(resolution)" ] }, { @@ -231,7 +369,7 @@ "ground_truth = torch.tensor(np.uint8(ground_truth))\n", "\n", "dices = []\n", - "rezs = range(128, 4096+4, 4)\n", + "rezs = range(128, 1920+4, 4)\n", "\n", "for rez in rezs:\n", "\n", @@ -271,32 +409,9 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-12T11:05:04.293064\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "plt.figure(figsize=(15, 10))\n", "plt.plot(rezs, dices)\n" diff --git a/src/utils/paste.py b/src/utils/paste.py index be6f25d..2e2e88c 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform): self, nb, image_dir, - scale_range=(0.05, 0.25), + scale_range=(0.05, 0.5), always_apply=True, p=1.0, ): @@ -138,7 +138,12 @@ class RandomPaste(A.DualTransform): # generate augmentations augmentations = [] NB = rd.randint(1, self.nb) - while len(augmentations) < NB: # TODO: mettre une condition d'arret ite max + ite = 0 + while len(augmentations) < NB: + + if ite > 100: + break + scale = rd.uniform(*self.scale_range) * min_scale shape = np.array(paste_shape * scale, dtype=np.uint)