# %% Imports and inline-plot configuration
import albumentations as A
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations.pytorch import ToTensorV2
from IPython import get_ipython
from PIL import Image

from utils.dice import dice_score

# Plain-Python spelling of the notebook's `%config` / `%matplotlib` line
# magics, kept in their original order around the pyplot import. Outside an
# IPython shell `get_ipython()` returns None and the configuration is skipped.
_shell = get_ipython()
if _shell is not None:
    _shell.run_line_magic("config", "InlineBackend.figure_formats = ['svg']")

import matplotlib.pyplot as plt  # noqa: E402  (must follow the %config magic)

if _shell is not None:
    _shell.run_line_magic("matplotlib", "inline")


# %% Helpers
def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    The value is computed from ``exp(-|x|)``, which lies in ``(0, 1]``, so
    large-magnitude inputs cannot overflow the way a direct ``np.exp(-x)``
    does for very negative ``x``.

    Args:
        x: Scalar or array-like of real numbers.

    Returns:
        ``numpy.ndarray`` with the same shape as ``x`` and values in [0, 1].
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))
    # For x >= 0: 1 / (1 + e^-x).  For x < 0: e^x / (1 + e^x).  Same function,
    # but each branch only ever exponentiates a non-positive number.
    return np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))


def to_numpy(tensor):
    """Return the contents of a torch tensor as a NumPy array on the CPU.

    ``detach()`` is a no-op for tensors that do not require grad, so a single
    call chain handles both the grad and no-grad cases.
    """
    return tensor.detach().cpu().numpy()


# %% Paths
# NOTE(review): `image_path` is an absolute, machine-specific path; point it
# at a local file before running on another machine.
model_path = "../../checkpoints/best.onnx"
# image_path = "../../images/SM.png"
image_path = "/home/lilian/data_disk/lfainsin/test/2022_SM/DOS_DETAIL/DSC_0050.jpg"

# %% Validate the ONNX graph and open an inference session
onnx_model = onnx.load(model_path)
onnx.checker.check_model(onnx_model)

session = onnxruntime.InferenceSession(model_path)

# %% Load the input photograph and preview it
# `convert` returns an independent in-memory image, so the file handle can be
# released as soon as the `with` block ends.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

plt.figure(figsize=(25, 10))
plt.imshow(image)

print(image.size)

# %% Preprocess: resize, scale to [0, 1], convert to a batched tensor
transform = A.Compose(
    [
        A.LongestMaxSize(1024),
        A.ToFloat(max_value=255),  # [0, 255] -> [0.0, 1.0]
        ToTensorV2(),  # HWC -> CHW
    ],
)
aug = transform(image=np.asarray(image))
img = aug["image"].unsqueeze(0)  # CHW -> 1CHW

# %% Run the model and preview its output without applying the sigmoid
inputs = {
    session.get_inputs()[0].name: to_numpy(img),
}

outs = session.run(None, inputs)

img_out = outs[0][0][0]  # extract HW
img_out = img_out * 255
img_out = img_out.clip(0, 255)  # values outside [0, 255] are saturated
img_out = np.uint8(img_out)  # [0, 255]
img_out = Image.fromarray(img_out, "L")  # PIL img

plt.figure(figsize=(25, 10))
plt.imshow(img_out)

# %% Preview the same output after the sigmoid
img_out = outs[0][0][0]  # extract HW
img_out = sigmoid(img_out)  # -> [0.0, 1.0]
img_out = img_out * 255  # [0.0, 255.0]
img_out = np.uint8(img_out)  # [0, 255]
img_out = Image.fromarray(img_out, "L")  # PIL img

plt.figure(figsize=(25, 10))
plt.imshow(img_out)

# %% Dice score as a function of the input resolution
image_path = "/home/lilian/data_disk/lfainsin/test/2022_SM/DOS_DETAIL/DSC_0050.jpg"
gt_path = "/home/lilian/data_disk/lfainsin/test/2022_SM/DOS_DETAIL/MASK.PNG"

# Decode both files once, as NumPy arrays, instead of converting them again
# on every iteration of the sweep.
with Image.open(image_path) as source_image:
    image = np.array(source_image.convert("RGB"))
with Image.open(gt_path) as source_mask:
    ground_truth = np.array(source_mask.convert("L"))

input_name = session.get_inputs()[0].name  # does not change between runs
# Same cutoff as the former `> 0.1` test on the [0.0, 1.0] scale, expressed on
# the mask's native [0, 255] scale (ToFloat is not applied to mask targets).
gt_cutoff = 0.1 * 255

dices = []
rezs = range(128, 4096 + 4, 4)

for rez in rezs:

    print(rez, end="\r")

    transform = A.Compose(
        [
            A.LongestMaxSize(rez),
            A.ToFloat(max_value=255),  # [0, 255] -> [0.0, 1.0]
            ToTensorV2(),  # HWC -> CHW
        ],
    )
    # Passing the ground truth as `mask=` resizes it together with the image
    # using the library's mask handling, and returns it as an HW tensor.
    # NOTE(review): this matches the rank of `outs[0][0][0]` assuming the
    # model output is 1x1xHxW, as the "extract HW" comments indicate.
    aug = transform(image=image, mask=ground_truth)
    img = aug["image"].unsqueeze(0)  # CHW -> 1CHW

    outs = session.run(None, {input_name: to_numpy(img)})

    img_out = outs[0][0][0]  # extract HW
    img_out = sigmoid(img_out) > 0.5  # {False, True}
    img_out = torch.tensor(np.uint8(img_out))  # [0, 1]

    gt = (aug["mask"] > gt_cutoff).to(torch.uint8)  # [0, 1]

    dice = dice_score(gt, img_out, logits=False)
    dices.append(dice)

# %% Plot the Dice score against the resolution
plt.figure(figsize=(15, 10))
plt.plot(rezs, dices)