feat: better pasting function
Former-commit-id: 43fedd3f6bafb51fe604e347f59b70cd5b0cc218 [formerly 51bb06c3b98df613710b329d3ade1febaf2b0b23] Former-commit-id: 46b89acd2b860d272ce8a13cf2c8c955d7545c46
This commit is contained in:
parent
1e388c6b90
commit
92ac3a2ab8
1
comp.ipynb.REMOVED.git-id
Normal file
1
comp.ipynb.REMOVED.git-id
Normal file
|
@ -0,0 +1 @@
|
|||
a27ff5fec9fd71b7846a70ebc473984e859912b8
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.utils.dice import dice_coeff, multiclass_dice_coeff
|
||||
import wandb
|
||||
from src.utils.dice import dice_coeff
|
||||
|
||||
|
||||
def evaluate(net, dataloader, device):
|
||||
|
@ -27,6 +27,12 @@ def evaluate(net, dataloader, device):
|
|||
|
||||
pbar.update(images.shape[0])
|
||||
|
||||
# save some images to wandb
|
||||
table = wandb.Table(columns=["image", "mask", "prediction"])
|
||||
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
|
||||
table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
||||
wandb.log({"predictions_table": table}, commit=False)
|
||||
|
||||
net.train()
|
||||
|
||||
# Fixes a potential division by zero error
|
||||
|
|
13
src/train.py
13
src/train.py
|
@ -14,7 +14,6 @@ from tqdm import tqdm
|
|||
import wandb
|
||||
from evaluate import evaluate
|
||||
from src.utils.dataset import SphereDataset
|
||||
from src.utils.dice import dice_loss
|
||||
from unet import UNet
|
||||
from utils.paste import RandomPaste
|
||||
|
||||
|
@ -43,7 +42,7 @@ def get_args():
|
|||
dest="batch_size",
|
||||
metavar="B",
|
||||
type=int,
|
||||
default=10,
|
||||
default=16,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -195,21 +194,21 @@ def main():
|
|||
|
||||
wandb.log( # log training metrics
|
||||
{
|
||||
"train/epoch": epoch + step / epoch,
|
||||
"train/epoch": epoch + step / len(train_loader),
|
||||
"train/train_loss": train_loss,
|
||||
}
|
||||
)
|
||||
|
||||
# Evaluation round
|
||||
val_loss = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_loss)
|
||||
val_score = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_score)
|
||||
wandb.log( # log validation metrics
|
||||
{
|
||||
"val/val_loss": val_loss,
|
||||
"val/val_score": val_score,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:3f}")
|
||||
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
|
||||
|
||||
# save weights when epoch end
|
||||
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
@ -2,7 +2,6 @@ import os
|
|||
import random as rd
|
||||
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
@ -23,16 +22,16 @@ class RandomPaste(A.DualTransform):
|
|||
def __init__(
|
||||
self,
|
||||
nb,
|
||||
scale_limit,
|
||||
path_paste_img_dir,
|
||||
path_paste_mask_dir,
|
||||
scale_range=(0.1, 0.2),
|
||||
always_apply=True,
|
||||
p=1.0,
|
||||
):
|
||||
super().__init__(always_apply, p)
|
||||
self.path_paste_img_dir = path_paste_img_dir
|
||||
self.path_paste_mask_dir = path_paste_mask_dir
|
||||
self.scale_limit = scale_limit
|
||||
self.scale_range = scale_range
|
||||
self.nb = nb
|
||||
|
||||
@property
|
||||
|
@ -40,76 +39,77 @@ class RandomPaste(A.DualTransform):
|
|||
return ["image"]
|
||||
|
||||
def apply(self, img, positions, paste_img, paste_mask, **params):
|
||||
img = img.copy()
|
||||
# convert img to Image, needed for `paste` function
|
||||
img = Image.fromarray(img)
|
||||
|
||||
w, h = paste_mask.shape
|
||||
mask_b = paste_mask > 0
|
||||
mask_rgb_b = np.stack([mask_b, mask_b, mask_b], axis=2)
|
||||
for pos in positions:
|
||||
img.paste(paste_img, pos, paste_mask)
|
||||
|
||||
for (x, y) in positions:
|
||||
img[x : x + w, y : y + h] = img[x : x + w, y : y + h] * ~mask_rgb_b + paste_img * mask_rgb_b
|
||||
|
||||
return img
|
||||
return np.asarray(img.convert("RGB"))
|
||||
|
||||
def apply_to_mask(self, mask, positions, paste_mask, **params):
|
||||
mask = mask.copy()
|
||||
# convert mask to Image, needed for `paste` function
|
||||
mask = Image.fromarray(mask)
|
||||
|
||||
w, h = paste_mask.shape
|
||||
mask_b = paste_mask > 0
|
||||
for pos in positions:
|
||||
mask.paste(paste_mask, pos, paste_mask)
|
||||
|
||||
for (x, y) in positions:
|
||||
mask[x : x + w, y : y + h] = mask[x : x + w, y : y + h] * ~mask_b + mask_b
|
||||
|
||||
return mask
|
||||
return np.asarray(mask.convert("L"))
|
||||
|
||||
def get_params_dependent_on_targets(self, params):
|
||||
# choose a random image inside the image folder
|
||||
filename = rd.choice(os.listdir(self.path_paste_img_dir))
|
||||
|
||||
paste_img = np.array(
|
||||
Image.open(
|
||||
# load the "paste" image
|
||||
paste_img = Image.open(
|
||||
os.path.join(
|
||||
self.path_paste_img_dir,
|
||||
filename,
|
||||
)
|
||||
).convert("RGB"),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
).convert("RGBA")
|
||||
|
||||
paste_mask = (
|
||||
np.array(
|
||||
Image.open(
|
||||
# load its respective mask
|
||||
paste_mask = Image.open(
|
||||
os.path.join(
|
||||
self.path_paste_mask_dir,
|
||||
filename,
|
||||
)
|
||||
).convert("L"),
|
||||
dtype=np.float32,
|
||||
)
|
||||
/ 255
|
||||
)
|
||||
).convert("LA")
|
||||
|
||||
# load the target image
|
||||
target_img = params["image"]
|
||||
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
|
||||
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
||||
|
||||
min_scale = min(
|
||||
target_img.shape[0] / paste_img.shape[0],
|
||||
target_img.shape[1] / paste_img.shape[1],
|
||||
# compute the minimum scaling to fit inside target image
|
||||
min_scale = np.min(target_shape / paste_shape)
|
||||
|
||||
# randomize the relative scaling
|
||||
scale = rd.uniform(*self.scale_range)
|
||||
|
||||
# rotate the image and its mask
|
||||
angle = rd.uniform(0, 360)
|
||||
paste_img = paste_img.rotate(angle, expand=True)
|
||||
paste_mask = paste_mask.rotate(angle, expand=True)
|
||||
|
||||
# scale the "paste" image and its mask
|
||||
paste_img = paste_img.resize(
|
||||
tuple((paste_shape * min_scale * scale).astype(np.uint)),
|
||||
resample=Image.Resampling.LANCZOS,
|
||||
)
|
||||
paste_mask = paste_mask.resize(
|
||||
tuple((paste_shape * min_scale * scale).astype(np.uint)),
|
||||
resample=Image.Resampling.LANCZOS,
|
||||
)
|
||||
|
||||
rescale_rotate = A.Compose(
|
||||
[
|
||||
A.Rotate(limit=360, always_apply=True, border_mode=cv2.BORDER_CONSTANT),
|
||||
A.RandomScale(scale_limit=(min_scale * self.scale_limit - 1, -0.99), always_apply=True),
|
||||
],
|
||||
)
|
||||
|
||||
augmentations = rescale_rotate(image=paste_img, mask=paste_mask)
|
||||
paste_img = augmentations["image"]
|
||||
paste_mask = augmentations["mask"]
|
||||
# update paste_shape after scaling
|
||||
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
||||
|
||||
# generate some positions
|
||||
positions = []
|
||||
for _ in range(rd.randint(1, self.nb)):
|
||||
x = rd.randint(0, target_img.shape[0] - paste_img.shape[0])
|
||||
y = rd.randint(0, target_img.shape[1] - paste_img.shape[1])
|
||||
x = rd.randint(0, target_shape[0] - paste_shape[0])
|
||||
y = rd.randint(0, target_shape[1] - paste_shape[1])
|
||||
positions.append((x, y))
|
||||
|
||||
params.update(
|
||||
|
@ -123,4 +123,4 @@ class RandomPaste(A.DualTransform):
|
|||
return params
|
||||
|
||||
def get_transform_init_args_names(self):
|
||||
return "scale_limit", "path_paste_img_dir", "path_paste_mask_dir"
|
||||
return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"
|
||||
|
|
Loading…
Reference in a new issue