From 0dd606144fbe1e5cc6a6bdf800cf74d8e65408c7 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 7 Jul 2022 12:06:41 +0200 Subject: [PATCH] feat: new paste dataset Former-commit-id: 039874208d5a27bf01beb2746a77502fd836ae5c [formerly 66638fcabaea1044d9a2fd48e6ffb20f149ebf47] Former-commit-id: 6bdf8bba0b3cbd8706337aa3167c36fba8855a4c --- comp.ipynb.REMOVED.git-id | 2 +- extract.ipynb | 177 ++++++++++++++++++++++++++++++++++++++ src/train.py | 14 ++- src/unet/model.py | 6 +- src/utils/paste.py | 28 +++--- 5 files changed, 206 insertions(+), 21 deletions(-) create mode 100644 extract.ipynb diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id index b439b71..3c6779c 100644 --- a/comp.ipynb.REMOVED.git-id +++ b/comp.ipynb.REMOVED.git-id @@ -1 +1 @@ -9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d \ No newline at end of file +3c9a34f197340a6051eb34d11695c7d6b72164f0 \ No newline at end of file diff --git a/extract.ipynb b/extract.ipynb new file mode 100644 index 0000000..f5c6750 --- /dev/null +++ b/extract.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import numpy as np\n", + "\n", + "import albumentations as A\n", + "\n", + "%config InlineBackend.figure_formats = ['svg']\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-07T10:16:03.003643\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "img = Image.open(\"/tmp/extract/photo.jpg\").convert(\"RGBA\")\n", + "mask = Image.open(\"/tmp/extract/MASK.PNG\").convert(\"LA\")\n", + "\n", + "plt.figure(figsize=(18, 10))\n", + "\n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(img)\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(mask)\n", + "\n", + "ax = plt.gca()\n", + "ax.set_facecolor('xkcd:salmon')\n", + "\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-07T10:22:39.796028\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(434, 425)\n" + ] + } + ], + "source": [ + "box = mask.getbbox()\n", + "\n", + "crop_img = img.crop(box)\n", + "crop_mask = mask.crop(box)\n", + "\n", + "plt.figure(figsize=(18, 10))\n", + "\n", + "plt.subplot(2, 2, 1)\n", + "plt.imshow(crop_img)\n", + "\n", + "plt.subplot(2, 2, 2)\n", + "plt.imshow(crop_mask)\n", + "\n", + "ax = plt.gca()\n", + "ax.set_facecolor('xkcd:salmon')\n", + "\n", + "empty = Image.fromarray(np.zeros(crop_img.size), \"RGBA\")\n", + "empty.paste(crop_img, crop_mask)\n", + "\n", + "plt.subplot(2, 2, 3)\n", + "plt.imshow(empty.resize((100, 100)))\n", + "\n", + "plt.subplot(2, 2, 4)\n", + "plt.imshow(crop_mask.resize((100, 100)))\n", + "\n", + "ax = plt.gca()\n", + "ax.set_facecolor('xkcd:salmon')\n", + "\n", + "plt.show()\n", + "\n", + "print(crop_img.size)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "spheres_to_extract_dir = \"/home/lilian/data_disk/lfainsin/test/\"\n", + "\n", + "spheres = list(Path(spheres_to_extract_dir).glob(\"**/*.jpg\"))\n", + "\n", + "parents = [path.parent for path in spheres]\n", + "parents = set(parents)\n", + "\n", + "for parent in parents:\n", + " mask_path = parent.joinpath(\"MASK.PNG\")\n", + " mask = Image.open(mask_path).convert(\"LA\")\n", + " box = mask.getbbox()\n", + " crop_mask = mask.crop(box)\n", + "\n", + " filename = Path(\"/tmp/saves/\" + str(mask_path).strip(spheres_to_extract_dir))\n", + " filename.parent.mkdir(parents=True, exist_ok=True)\n", + " crop_mask.save(filename)\n", + "\n", + " spheres = list(parent.glob(\"*.jpg\"))\n", + " for sphere in spheres:\n", + " img = Image.open(sphere).convert(\"RGB\")\n", + " crop_img = img.crop(box)\n", + "\n", + " filename = Path(\"/tmp/saves/\" + str(sphere).strip(spheres_to_extract_dir))\n", + " filename.parent.mkdir(parents=True, exist_ok=True)\n", + " crop_img.save(filename)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.0 ('.venv': poetry)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/train.py b/src/train.py index d97802d..eb1b115 100644 --- a/src/train.py +++ b/src/train.py @@ -15,16 +15,22 @@ CONFIG = { "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", - "FEATURES": [16, 32, 64, 128], + # "FEATURES": [1, 2, 4, 8], + # "FEATURES": [4, 8, 16, 32], + # "FEATURES": [8, 16, 32, 64], + # "FEATURES": [4, 8, 16, 32, 64], + "FEATURES": [8, 16, 32, 64, 128], + # "FEATURES": [16, 32, 64, 128], + # "FEATURES": [64, 128, 256, 512], "N_CHANNELS": 3, "N_CLASSES": 1, "AMP": True, "PIN_MEMORY": True, "BENCHMARK": True, "DEVICE": "gpu", - "WORKERS": 8, - "EPOCHS": 10, - "BATCH_SIZE": 16, + "WORKERS": 10, + "EPOCHS": 1, + "BATCH_SIZE": 32, "LEARNING_RATE": 1e-4, "WEIGHT_DECAY": 1e-8, "MOMENTUM": 0.9, diff --git a/src/unet/model.py b/src/unet/model.py index 4e87827..4508b27 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -82,7 +82,7 @@ class UNet(pl.LightningModule): ) ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) + # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) return DataLoader( ds_train, @@ -178,6 +178,8 @@ class UNet(pl.LightningModule): }, }, ), + dice, + dice_bin, ] ) @@ -199,7 +201,7 @@ class UNet(pl.LightningModule): mae = torch.stack([d["mae"] for d in validation_outputs]).mean() # table unpacking - columns = ["ID", "image", "ground truth", "prediction"] + columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] rowss = [d["table_rows"] for d in validation_outputs] rows = list(itertools.chain.from_iterable(rowss)) diff --git a/src/utils/paste.py b/src/utils/paste.py index a1e24e4..a25289b 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -1,5 +1,6 @@ import os import random as rd +from pathlib import Path import albumentations as A import numpy as np @@ -22,15 +23,15 @@ class RandomPaste(A.DualTransform): def __init__( self, nb, - path_paste_img_dir, - path_paste_mask_dir, + image_dir, scale_range=(0.1, 0.2), always_apply=True, p=1.0, ): super().__init__(always_apply, p) - self.path_paste_img_dir = path_paste_img_dir - self.path_paste_mask_dir = path_paste_mask_dir + self.images = [] + self.images.extend(list(Path(image_dir).glob("**/*.jpg"))) + self.images.extend(list(Path(image_dir).glob("**/*.png"))) self.scale_range = scale_range self.nb = nb @@ -69,14 +70,15 @@ class RandomPaste(A.DualTransform): return False def get_params_dependent_on_targets(self, params): - # choose a random image inside the image folder - filename = rd.choice(os.listdir(self.path_paste_img_dir)) + # choose a random image and its corresponding mask + img_path = rd.choice(self.images) + mask_path = img_path.parent.joinpath("MASK.PNG") # load the "paste" image paste_img = Image.open( os.path.join( self.path_paste_img_dir, - filename, + img_path, ) ).convert("RGBA") @@ -84,25 +86,23 @@ class RandomPaste(A.DualTransform): paste_mask = Image.open( os.path.join( self.path_paste_mask_dir, - filename, + mask_path, ) ).convert("LA") # load the target image target_img = params["image"] + + # compute shapes, for easier computations target_shape = np.array(target_img.shape[:2], dtype=np.uint) paste_shape = np.array(paste_img.size, dtype=np.uint) - # change paste_img's brightness randomly - filter = ImageEnhance.Brightness(paste_img) - paste_img = filter.enhance(rd.uniform(0.5, 1.5)) - # change paste_img's contrast randomly filter = ImageEnhance.Contrast(paste_img) paste_img = filter.enhance(rd.uniform(0.5, 1.5)) - # change paste_img's sharpness randomly - filter = ImageEnhance.Sharpness(paste_img) + # change paste_img's brightness randomly + filter = ImageEnhance.Brightness(paste_img) paste_img = filter.enhance(rd.uniform(0.5, 1.5)) # compute the minimum scaling to fit inside target image