feat: new paste dataset

Former-commit-id: 039874208d5a27bf01beb2746a77502fd836ae5c [formerly 66638fcabaea1044d9a2fd48e6ffb20f149ebf47]
Former-commit-id: 6bdf8bba0b3cbd8706337aa3167c36fba8855a4c
This commit is contained in:
Laurent Fainsin 2022-07-07 12:06:41 +02:00
parent b71b57285f
commit 0dd606144f
5 changed files with 206 additions and 21 deletions

View file

@ -1 +1 @@
9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d 3c9a34f197340a6051eb34d11695c7d6b72164f0

177
extract.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@ -15,16 +15,22 @@ CONFIG = {
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
"DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
"DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
"FEATURES": [16, 32, 64, 128], # "FEATURES": [1, 2, 4, 8],
# "FEATURES": [4, 8, 16, 32],
# "FEATURES": [8, 16, 32, 64],
# "FEATURES": [4, 8, 16, 32, 64],
"FEATURES": [8, 16, 32, 64, 128],
# "FEATURES": [16, 32, 64, 128],
# "FEATURES": [64, 128, 256, 512],
"N_CHANNELS": 3, "N_CHANNELS": 3,
"N_CLASSES": 1, "N_CLASSES": 1,
"AMP": True, "AMP": True,
"PIN_MEMORY": True, "PIN_MEMORY": True,
"BENCHMARK": True, "BENCHMARK": True,
"DEVICE": "gpu", "DEVICE": "gpu",
"WORKERS": 8, "WORKERS": 10,
"EPOCHS": 10, "EPOCHS": 1,
"BATCH_SIZE": 16, "BATCH_SIZE": 32,
"LEARNING_RATE": 1e-4, "LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 1e-8, "WEIGHT_DECAY": 1e-8,
"MOMENTUM": 0.9, "MOMENTUM": 0.9,

View file

@ -82,7 +82,7 @@ class UNet(pl.LightningModule):
) )
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
return DataLoader( return DataLoader(
ds_train, ds_train,
@ -178,6 +178,8 @@ class UNet(pl.LightningModule):
}, },
}, },
), ),
dice,
dice_bin,
] ]
) )
@ -199,7 +201,7 @@ class UNet(pl.LightningModule):
mae = torch.stack([d["mae"] for d in validation_outputs]).mean() mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
# table unpacking # table unpacking
columns = ["ID", "image", "ground truth", "prediction"] columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
rowss = [d["table_rows"] for d in validation_outputs] rowss = [d["table_rows"] for d in validation_outputs]
rows = list(itertools.chain.from_iterable(rowss)) rows = list(itertools.chain.from_iterable(rowss))

View file

@ -1,5 +1,6 @@
import os import os
import random as rd import random as rd
from pathlib import Path
import albumentations as A import albumentations as A
import numpy as np import numpy as np
@ -22,15 +23,15 @@ class RandomPaste(A.DualTransform):
def __init__( def __init__(
self, self,
nb, nb,
path_paste_img_dir, image_dir,
path_paste_mask_dir,
scale_range=(0.1, 0.2), scale_range=(0.1, 0.2),
always_apply=True, always_apply=True,
p=1.0, p=1.0,
): ):
super().__init__(always_apply, p) super().__init__(always_apply, p)
self.path_paste_img_dir = path_paste_img_dir self.images = []
self.path_paste_mask_dir = path_paste_mask_dir self.images.extend(list(Path(image_dir).glob("**/*.jpg")))
self.images.extend(list(Path(image_dir).glob("**/*.png")))
self.scale_range = scale_range self.scale_range = scale_range
self.nb = nb self.nb = nb
@ -69,14 +70,15 @@ class RandomPaste(A.DualTransform):
return False return False
def get_params_dependent_on_targets(self, params): def get_params_dependent_on_targets(self, params):
# choose a random image inside the image folder # choose a random image and its corresponding mask
filename = rd.choice(os.listdir(self.path_paste_img_dir)) img_path = rd.choice(self.images)
mask_path = img_path.parent.joinpath("MASK.PNG")
# load the "paste" image # load the "paste" image
paste_img = Image.open( paste_img = Image.open(
os.path.join( os.path.join(
self.path_paste_img_dir, self.path_paste_img_dir,
filename, img_path,
) )
).convert("RGBA") ).convert("RGBA")
@ -84,25 +86,23 @@ class RandomPaste(A.DualTransform):
paste_mask = Image.open( paste_mask = Image.open(
os.path.join( os.path.join(
self.path_paste_mask_dir, self.path_paste_mask_dir,
filename, mask_path,
) )
).convert("LA") ).convert("LA")
# load the target image # load the target image
target_img = params["image"] target_img = params["image"]
# compute shapes, for easier computations
target_shape = np.array(target_img.shape[:2], dtype=np.uint) target_shape = np.array(target_img.shape[:2], dtype=np.uint)
paste_shape = np.array(paste_img.size, dtype=np.uint) paste_shape = np.array(paste_img.size, dtype=np.uint)
# change paste_img's brightness randomly
filter = ImageEnhance.Brightness(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# change paste_img's contrast randomly # change paste_img's contrast randomly
filter = ImageEnhance.Contrast(paste_img) filter = ImageEnhance.Contrast(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5)) paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# change paste_img's sharpness randomly # change paste_img's brightness randomly
filter = ImageEnhance.Sharpness(paste_img) filter = ImageEnhance.Brightness(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5)) paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# compute the minimum scaling to fit inside target image # compute the minimum scaling to fit inside target image