mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 23:12:05 +00:00
feat: new paste dataset
Former-commit-id: 039874208d5a27bf01beb2746a77502fd836ae5c [formerly 66638fcabaea1044d9a2fd48e6ffb20f149ebf47] Former-commit-id: 6bdf8bba0b3cbd8706337aa3167c36fba8855a4c
This commit is contained in:
parent
b71b57285f
commit
0dd606144f
|
@ -1 +1 @@
|
||||||
9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d
|
3c9a34f197340a6051eb34d11695c7d6b72164f0
|
177
extract.ipynb
Normal file
177
extract.ipynb
Normal file
File diff suppressed because one or more lines are too long
14
src/train.py
14
src/train.py
|
@ -15,16 +15,22 @@ CONFIG = {
|
||||||
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
|
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
|
||||||
"DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
|
"DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||||
"DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
"DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||||
"FEATURES": [16, 32, 64, 128],
|
# "FEATURES": [1, 2, 4, 8],
|
||||||
|
# "FEATURES": [4, 8, 16, 32],
|
||||||
|
# "FEATURES": [8, 16, 32, 64],
|
||||||
|
# "FEATURES": [4, 8, 16, 32, 64],
|
||||||
|
"FEATURES": [8, 16, 32, 64, 128],
|
||||||
|
# "FEATURES": [16, 32, 64, 128],
|
||||||
|
# "FEATURES": [64, 128, 256, 512],
|
||||||
"N_CHANNELS": 3,
|
"N_CHANNELS": 3,
|
||||||
"N_CLASSES": 1,
|
"N_CLASSES": 1,
|
||||||
"AMP": True,
|
"AMP": True,
|
||||||
"PIN_MEMORY": True,
|
"PIN_MEMORY": True,
|
||||||
"BENCHMARK": True,
|
"BENCHMARK": True,
|
||||||
"DEVICE": "gpu",
|
"DEVICE": "gpu",
|
||||||
"WORKERS": 8,
|
"WORKERS": 10,
|
||||||
"EPOCHS": 10,
|
"EPOCHS": 1,
|
||||||
"BATCH_SIZE": 16,
|
"BATCH_SIZE": 32,
|
||||||
"LEARNING_RATE": 1e-4,
|
"LEARNING_RATE": 1e-4,
|
||||||
"WEIGHT_DECAY": 1e-8,
|
"WEIGHT_DECAY": 1e-8,
|
||||||
"MOMENTUM": 0.9,
|
"MOMENTUM": 0.9,
|
||||||
|
|
|
@ -82,7 +82,7 @@ class UNet(pl.LightningModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||||
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
|
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
ds_train,
|
ds_train,
|
||||||
|
@ -178,6 +178,8 @@ class UNet(pl.LightningModule):
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
dice,
|
||||||
|
dice_bin,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -199,7 +201,7 @@ class UNet(pl.LightningModule):
|
||||||
mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
|
mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
|
||||||
|
|
||||||
# table unpacking
|
# table unpacking
|
||||||
columns = ["ID", "image", "ground truth", "prediction"]
|
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
||||||
rowss = [d["table_rows"] for d in validation_outputs]
|
rowss = [d["table_rows"] for d in validation_outputs]
|
||||||
rows = list(itertools.chain.from_iterable(rowss))
|
rows = list(itertools.chain.from_iterable(rowss))
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import random as rd
|
import random as rd
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -22,15 +23,15 @@ class RandomPaste(A.DualTransform):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
nb,
|
nb,
|
||||||
path_paste_img_dir,
|
image_dir,
|
||||||
path_paste_mask_dir,
|
|
||||||
scale_range=(0.1, 0.2),
|
scale_range=(0.1, 0.2),
|
||||||
always_apply=True,
|
always_apply=True,
|
||||||
p=1.0,
|
p=1.0,
|
||||||
):
|
):
|
||||||
super().__init__(always_apply, p)
|
super().__init__(always_apply, p)
|
||||||
self.path_paste_img_dir = path_paste_img_dir
|
self.images = []
|
||||||
self.path_paste_mask_dir = path_paste_mask_dir
|
self.images.extend(list(Path(image_dir).glob("**/*.jpg")))
|
||||||
|
self.images.extend(list(Path(image_dir).glob("**/*.png")))
|
||||||
self.scale_range = scale_range
|
self.scale_range = scale_range
|
||||||
self.nb = nb
|
self.nb = nb
|
||||||
|
|
||||||
|
@ -69,14 +70,15 @@ class RandomPaste(A.DualTransform):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_params_dependent_on_targets(self, params):
|
def get_params_dependent_on_targets(self, params):
|
||||||
# choose a random image inside the image folder
|
# choose a random image and its corresponding mask
|
||||||
filename = rd.choice(os.listdir(self.path_paste_img_dir))
|
img_path = rd.choice(self.images)
|
||||||
|
mask_path = img_path.parent.joinpath("MASK.PNG")
|
||||||
|
|
||||||
# load the "paste" image
|
# load the "paste" image
|
||||||
paste_img = Image.open(
|
paste_img = Image.open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
self.path_paste_img_dir,
|
self.path_paste_img_dir,
|
||||||
filename,
|
img_path,
|
||||||
)
|
)
|
||||||
).convert("RGBA")
|
).convert("RGBA")
|
||||||
|
|
||||||
|
@ -84,25 +86,23 @@ class RandomPaste(A.DualTransform):
|
||||||
paste_mask = Image.open(
|
paste_mask = Image.open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
self.path_paste_mask_dir,
|
self.path_paste_mask_dir,
|
||||||
filename,
|
mask_path,
|
||||||
)
|
)
|
||||||
).convert("LA")
|
).convert("LA")
|
||||||
|
|
||||||
# load the target image
|
# load the target image
|
||||||
target_img = params["image"]
|
target_img = params["image"]
|
||||||
|
|
||||||
|
# compute shapes, for easier computations
|
||||||
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
|
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
|
||||||
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
||||||
|
|
||||||
# change paste_img's brightness randomly
|
|
||||||
filter = ImageEnhance.Brightness(paste_img)
|
|
||||||
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
|
||||||
|
|
||||||
# change paste_img's contrast randomly
|
# change paste_img's contrast randomly
|
||||||
filter = ImageEnhance.Contrast(paste_img)
|
filter = ImageEnhance.Contrast(paste_img)
|
||||||
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
||||||
|
|
||||||
# change paste_img's sharpness randomly
|
# change paste_img's brightness randomly
|
||||||
filter = ImageEnhance.Sharpness(paste_img)
|
filter = ImageEnhance.Brightness(paste_img)
|
||||||
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
||||||
|
|
||||||
# compute the minimum scaling to fit inside target image
|
# compute the minimum scaling to fit inside target image
|
||||||
|
|
Loading…
Reference in a new issue