feat: better pasting function

Former-commit-id: 43fedd3f6bafb51fe604e347f59b70cd5b0cc218 [formerly 51bb06c3b98df613710b329d3ade1febaf2b0b23]
Former-commit-id: 46b89acd2b860d272ce8a13cf2c8c955d7545c46
This commit is contained in:
Your Name 2022-06-28 16:36:50 +02:00
parent 1e388c6b90
commit 92ac3a2ab8
4 changed files with 70 additions and 64 deletions

View file

@ -0,0 +1 @@
a27ff5fec9fd71b7846a70ebc473984e859912b8

View file

@ -1,8 +1,8 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from src.utils.dice import dice_coeff, multiclass_dice_coeff
import wandb
from src.utils.dice import dice_coeff
def evaluate(net, dataloader, device):
@ -27,6 +27,12 @@ def evaluate(net, dataloader, device):
pbar.update(images.shape[0])
# save some images to wandb
table = wandb.Table(columns=["image", "mask", "prediction"])
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
wandb.log({"predictions_table": table}, commit=False)
net.train()
# Fixes a potential division by zero error

View file

@ -14,7 +14,6 @@ from tqdm import tqdm
import wandb
from evaluate import evaluate
from src.utils.dataset import SphereDataset
from src.utils.dice import dice_loss
from unet import UNet
from utils.paste import RandomPaste
@ -43,7 +42,7 @@ def get_args():
dest="batch_size",
metavar="B",
type=int,
default=10,
default=16,
help="Batch size",
)
parser.add_argument(
@ -195,21 +194,21 @@ def main():
wandb.log( # log training metrics
{
"train/epoch": epoch + step / epoch,
"train/epoch": epoch + step / len(train_loader),
"train/train_loss": train_loss,
}
)
# Evaluation round
val_loss = evaluate(net, val_loader, device)
scheduler.step(val_loss)
val_score = evaluate(net, val_loader, device)
scheduler.step(val_score)
wandb.log( # log validation metrics
{
"val/val_loss": val_loss,
"val/val_score": val_score,
}
)
print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:3f}")
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
# save weights when epoch end
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)

View file

@ -2,7 +2,6 @@ import os
import random as rd
import albumentations as A
import cv2
import numpy as np
from PIL import Image
@ -23,16 +22,16 @@ class RandomPaste(A.DualTransform):
def __init__(
self,
nb,
scale_limit,
path_paste_img_dir,
path_paste_mask_dir,
scale_range=(0.1, 0.2),
always_apply=True,
p=1.0,
):
super().__init__(always_apply, p)
self.path_paste_img_dir = path_paste_img_dir
self.path_paste_mask_dir = path_paste_mask_dir
self.scale_limit = scale_limit
self.scale_range = scale_range
self.nb = nb
@property
@ -40,76 +39,77 @@ class RandomPaste(A.DualTransform):
return ["image"]
def apply(self, img, positions, paste_img, paste_mask, **params):
img = img.copy()
# convert img to Image, needed for `paste` function
img = Image.fromarray(img)
w, h = paste_mask.shape
mask_b = paste_mask > 0
mask_rgb_b = np.stack([mask_b, mask_b, mask_b], axis=2)
for pos in positions:
img.paste(paste_img, pos, paste_mask)
for (x, y) in positions:
img[x : x + w, y : y + h] = img[x : x + w, y : y + h] * ~mask_rgb_b + paste_img * mask_rgb_b
return img
return np.asarray(img.convert("RGB"))
def apply_to_mask(self, mask, positions, paste_mask, **params):
mask = mask.copy()
# convert mask to Image, needed for `paste` function
mask = Image.fromarray(mask)
w, h = paste_mask.shape
mask_b = paste_mask > 0
for pos in positions:
mask.paste(paste_mask, pos, paste_mask)
for (x, y) in positions:
mask[x : x + w, y : y + h] = mask[x : x + w, y : y + h] * ~mask_b + mask_b
return mask
return np.asarray(mask.convert("L"))
def get_params_dependent_on_targets(self, params):
# choose a random image inside the image folder
filename = rd.choice(os.listdir(self.path_paste_img_dir))
paste_img = np.array(
Image.open(
os.path.join(
self.path_paste_img_dir,
filename,
)
).convert("RGB"),
dtype=np.uint8,
)
paste_mask = (
np.array(
Image.open(
os.path.join(
self.path_paste_mask_dir,
filename,
)
).convert("L"),
dtype=np.float32,
# load the "paste" image
paste_img = Image.open(
os.path.join(
self.path_paste_img_dir,
filename,
)
/ 255
)
).convert("RGBA")
# load its respective mask
paste_mask = Image.open(
os.path.join(
self.path_paste_mask_dir,
filename,
)
).convert("LA")
# load the target image
target_img = params["image"]
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
paste_shape = np.array(paste_img.size, dtype=np.uint)
min_scale = min(
target_img.shape[0] / paste_img.shape[0],
target_img.shape[1] / paste_img.shape[1],
# compute the minimum scaling to fit inside target image
min_scale = np.min(target_shape / paste_shape)
# randomize the relative scaling
scale = rd.uniform(*self.scale_range)
# rotate the image and its mask
angle = rd.uniform(0, 360)
paste_img = paste_img.rotate(angle, expand=True)
paste_mask = paste_mask.rotate(angle, expand=True)
# scale the "paste" image and its mask
paste_img = paste_img.resize(
tuple((paste_shape * min_scale * scale).astype(np.uint)),
resample=Image.Resampling.LANCZOS,
)
paste_mask = paste_mask.resize(
tuple((paste_shape * min_scale * scale).astype(np.uint)),
resample=Image.Resampling.LANCZOS,
)
rescale_rotate = A.Compose(
[
A.Rotate(limit=360, always_apply=True, border_mode=cv2.BORDER_CONSTANT),
A.RandomScale(scale_limit=(min_scale * self.scale_limit - 1, -0.99), always_apply=True),
],
)
augmentations = rescale_rotate(image=paste_img, mask=paste_mask)
paste_img = augmentations["image"]
paste_mask = augmentations["mask"]
# update paste_shape after scaling
paste_shape = np.array(paste_img.size, dtype=np.uint)
# generate some positions
positions = []
for _ in range(rd.randint(1, self.nb)):
x = rd.randint(0, target_img.shape[0] - paste_img.shape[0])
y = rd.randint(0, target_img.shape[1] - paste_img.shape[1])
x = rd.randint(0, target_shape[0] - paste_shape[0])
y = rd.randint(0, target_shape[1] - paste_shape[1])
positions.append((x, y))
params.update(
@ -123,4 +123,4 @@ class RandomPaste(A.DualTransform):
return params
def get_transform_init_args_names(self):
return "scale_limit", "path_paste_img_dir", "path_paste_mask_dir"
return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"