feat: better pasting function
Former-commit-id: 43fedd3f6bafb51fe604e347f59b70cd5b0cc218 [formerly 51bb06c3b98df613710b329d3ade1febaf2b0b23] Former-commit-id: 46b89acd2b860d272ce8a13cf2c8c955d7545c46
This commit is contained in:
parent
1e388c6b90
commit
92ac3a2ab8
1
comp.ipynb.REMOVED.git-id
Normal file
1
comp.ipynb.REMOVED.git-id
Normal file
|
@ -0,0 +1 @@
|
||||||
|
a27ff5fec9fd71b7846a70ebc473984e859912b8
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from src.utils.dice import dice_coeff, multiclass_dice_coeff
|
import wandb
|
||||||
|
from src.utils.dice import dice_coeff
|
||||||
|
|
||||||
|
|
||||||
def evaluate(net, dataloader, device):
|
def evaluate(net, dataloader, device):
|
||||||
|
@ -27,6 +27,12 @@ def evaluate(net, dataloader, device):
|
||||||
|
|
||||||
pbar.update(images.shape[0])
|
pbar.update(images.shape[0])
|
||||||
|
|
||||||
|
# save some images to wandb
|
||||||
|
table = wandb.Table(columns=["image", "mask", "prediction"])
|
||||||
|
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
|
||||||
|
table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
||||||
|
wandb.log({"predictions_table": table}, commit=False)
|
||||||
|
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
# Fixes a potential division by zero error
|
# Fixes a potential division by zero error
|
||||||
|
|
13
src/train.py
13
src/train.py
|
@ -14,7 +14,6 @@ from tqdm import tqdm
|
||||||
import wandb
|
import wandb
|
||||||
from evaluate import evaluate
|
from evaluate import evaluate
|
||||||
from src.utils.dataset import SphereDataset
|
from src.utils.dataset import SphereDataset
|
||||||
from src.utils.dice import dice_loss
|
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils.paste import RandomPaste
|
from utils.paste import RandomPaste
|
||||||
|
|
||||||
|
@ -43,7 +42,7 @@ def get_args():
|
||||||
dest="batch_size",
|
dest="batch_size",
|
||||||
metavar="B",
|
metavar="B",
|
||||||
type=int,
|
type=int,
|
||||||
default=10,
|
default=16,
|
||||||
help="Batch size",
|
help="Batch size",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -195,21 +194,21 @@ def main():
|
||||||
|
|
||||||
wandb.log( # log training metrics
|
wandb.log( # log training metrics
|
||||||
{
|
{
|
||||||
"train/epoch": epoch + step / epoch,
|
"train/epoch": epoch + step / len(train_loader),
|
||||||
"train/train_loss": train_loss,
|
"train/train_loss": train_loss,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
val_loss = evaluate(net, val_loader, device)
|
val_score = evaluate(net, val_loader, device)
|
||||||
scheduler.step(val_loss)
|
scheduler.step(val_score)
|
||||||
wandb.log( # log validation metrics
|
wandb.log( # log validation metrics
|
||||||
{
|
{
|
||||||
"val/val_loss": val_loss,
|
"val/val_score": val_score,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:3f}")
|
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
|
||||||
|
|
||||||
# save weights when epoch end
|
# save weights when epoch end
|
||||||
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
|
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
|
@ -2,7 +2,6 @@ import os
|
||||||
import random as rd
|
import random as rd
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -23,16 +22,16 @@ class RandomPaste(A.DualTransform):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
nb,
|
nb,
|
||||||
scale_limit,
|
|
||||||
path_paste_img_dir,
|
path_paste_img_dir,
|
||||||
path_paste_mask_dir,
|
path_paste_mask_dir,
|
||||||
|
scale_range=(0.1, 0.2),
|
||||||
always_apply=True,
|
always_apply=True,
|
||||||
p=1.0,
|
p=1.0,
|
||||||
):
|
):
|
||||||
super().__init__(always_apply, p)
|
super().__init__(always_apply, p)
|
||||||
self.path_paste_img_dir = path_paste_img_dir
|
self.path_paste_img_dir = path_paste_img_dir
|
||||||
self.path_paste_mask_dir = path_paste_mask_dir
|
self.path_paste_mask_dir = path_paste_mask_dir
|
||||||
self.scale_limit = scale_limit
|
self.scale_range = scale_range
|
||||||
self.nb = nb
|
self.nb = nb
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -40,76 +39,77 @@ class RandomPaste(A.DualTransform):
|
||||||
return ["image"]
|
return ["image"]
|
||||||
|
|
||||||
def apply(self, img, positions, paste_img, paste_mask, **params):
|
def apply(self, img, positions, paste_img, paste_mask, **params):
|
||||||
img = img.copy()
|
# convert img to Image, needed for `paste` function
|
||||||
|
img = Image.fromarray(img)
|
||||||
|
|
||||||
w, h = paste_mask.shape
|
for pos in positions:
|
||||||
mask_b = paste_mask > 0
|
img.paste(paste_img, pos, paste_mask)
|
||||||
mask_rgb_b = np.stack([mask_b, mask_b, mask_b], axis=2)
|
|
||||||
|
|
||||||
for (x, y) in positions:
|
return np.asarray(img.convert("RGB"))
|
||||||
img[x : x + w, y : y + h] = img[x : x + w, y : y + h] * ~mask_rgb_b + paste_img * mask_rgb_b
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
def apply_to_mask(self, mask, positions, paste_mask, **params):
|
def apply_to_mask(self, mask, positions, paste_mask, **params):
|
||||||
mask = mask.copy()
|
# convert mask to Image, needed for `paste` function
|
||||||
|
mask = Image.fromarray(mask)
|
||||||
|
|
||||||
w, h = paste_mask.shape
|
for pos in positions:
|
||||||
mask_b = paste_mask > 0
|
mask.paste(paste_mask, pos, paste_mask)
|
||||||
|
|
||||||
for (x, y) in positions:
|
return np.asarray(mask.convert("L"))
|
||||||
mask[x : x + w, y : y + h] = mask[x : x + w, y : y + h] * ~mask_b + mask_b
|
|
||||||
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def get_params_dependent_on_targets(self, params):
|
def get_params_dependent_on_targets(self, params):
|
||||||
|
# choose a random image inside the image folder
|
||||||
filename = rd.choice(os.listdir(self.path_paste_img_dir))
|
filename = rd.choice(os.listdir(self.path_paste_img_dir))
|
||||||
|
|
||||||
paste_img = np.array(
|
# load the "paste" image
|
||||||
Image.open(
|
paste_img = Image.open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
self.path_paste_img_dir,
|
self.path_paste_img_dir,
|
||||||
filename,
|
filename,
|
||||||
)
|
)
|
||||||
).convert("RGB"),
|
).convert("RGBA")
|
||||||
dtype=np.uint8,
|
|
||||||
)
|
|
||||||
|
|
||||||
paste_mask = (
|
# load its respective mask
|
||||||
np.array(
|
paste_mask = Image.open(
|
||||||
Image.open(
|
|
||||||
os.path.join(
|
os.path.join(
|
||||||
self.path_paste_mask_dir,
|
self.path_paste_mask_dir,
|
||||||
filename,
|
filename,
|
||||||
)
|
)
|
||||||
).convert("L"),
|
).convert("LA")
|
||||||
dtype=np.float32,
|
|
||||||
)
|
|
||||||
/ 255
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# load the target image
|
||||||
target_img = params["image"]
|
target_img = params["image"]
|
||||||
|
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
|
||||||
|
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
||||||
|
|
||||||
min_scale = min(
|
# compute the minimum scaling to fit inside target image
|
||||||
target_img.shape[0] / paste_img.shape[0],
|
min_scale = np.min(target_shape / paste_shape)
|
||||||
target_img.shape[1] / paste_img.shape[1],
|
|
||||||
|
# randomize the relative scaling
|
||||||
|
scale = rd.uniform(*self.scale_range)
|
||||||
|
|
||||||
|
# rotate the image and its mask
|
||||||
|
angle = rd.uniform(0, 360)
|
||||||
|
paste_img = paste_img.rotate(angle, expand=True)
|
||||||
|
paste_mask = paste_mask.rotate(angle, expand=True)
|
||||||
|
|
||||||
|
# scale the "paste" image and its mask
|
||||||
|
paste_img = paste_img.resize(
|
||||||
|
tuple((paste_shape * min_scale * scale).astype(np.uint)),
|
||||||
|
resample=Image.Resampling.LANCZOS,
|
||||||
|
)
|
||||||
|
paste_mask = paste_mask.resize(
|
||||||
|
tuple((paste_shape * min_scale * scale).astype(np.uint)),
|
||||||
|
resample=Image.Resampling.LANCZOS,
|
||||||
)
|
)
|
||||||
|
|
||||||
rescale_rotate = A.Compose(
|
# update paste_shape after scaling
|
||||||
[
|
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
||||||
A.Rotate(limit=360, always_apply=True, border_mode=cv2.BORDER_CONSTANT),
|
|
||||||
A.RandomScale(scale_limit=(min_scale * self.scale_limit - 1, -0.99), always_apply=True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
augmentations = rescale_rotate(image=paste_img, mask=paste_mask)
|
|
||||||
paste_img = augmentations["image"]
|
|
||||||
paste_mask = augmentations["mask"]
|
|
||||||
|
|
||||||
|
# generate some positions
|
||||||
positions = []
|
positions = []
|
||||||
for _ in range(rd.randint(1, self.nb)):
|
for _ in range(rd.randint(1, self.nb)):
|
||||||
x = rd.randint(0, target_img.shape[0] - paste_img.shape[0])
|
x = rd.randint(0, target_shape[0] - paste_shape[0])
|
||||||
y = rd.randint(0, target_img.shape[1] - paste_img.shape[1])
|
y = rd.randint(0, target_shape[1] - paste_shape[1])
|
||||||
positions.append((x, y))
|
positions.append((x, y))
|
||||||
|
|
||||||
params.update(
|
params.update(
|
||||||
|
@ -123,4 +123,4 @@ class RandomPaste(A.DualTransform):
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def get_transform_init_args_names(self):
|
def get_transform_init_args_names(self):
|
||||||
return "scale_limit", "path_paste_img_dir", "path_paste_mask_dir"
|
return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"
|
||||||
|
|
Loading…
Reference in a new issue