feat: better pasting function

Former-commit-id: 43fedd3f6bafb51fe604e347f59b70cd5b0cc218 [formerly 51bb06c3b98df613710b329d3ade1febaf2b0b23]
Former-commit-id: 46b89acd2b860d272ce8a13cf2c8c955d7545c46
This commit is contained in:
Your Name 2022-06-28 16:36:50 +02:00
parent 1e388c6b90
commit 92ac3a2ab8
4 changed files with 70 additions and 64 deletions

View file

@ -0,0 +1 @@
a27ff5fec9fd71b7846a70ebc473984e859912b8

View file

@ -1,8 +1,8 @@
import torch import torch
import torch.nn.functional as F
from tqdm import tqdm from tqdm import tqdm
from src.utils.dice import dice_coeff, multiclass_dice_coeff import wandb
from src.utils.dice import dice_coeff
def evaluate(net, dataloader, device): def evaluate(net, dataloader, device):
@ -27,6 +27,12 @@ def evaluate(net, dataloader, device):
pbar.update(images.shape[0]) pbar.update(images.shape[0])
# save some images to wandb
table = wandb.Table(columns=["image", "mask", "prediction"])
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
wandb.log({"predictions_table": table}, commit=False)
net.train() net.train()
# Fixes a potential division by zero error # Fixes a potential division by zero error

View file

@ -14,7 +14,6 @@ from tqdm import tqdm
import wandb import wandb
from evaluate import evaluate from evaluate import evaluate
from src.utils.dataset import SphereDataset from src.utils.dataset import SphereDataset
from src.utils.dice import dice_loss
from unet import UNet from unet import UNet
from utils.paste import RandomPaste from utils.paste import RandomPaste
@ -43,7 +42,7 @@ def get_args():
dest="batch_size", dest="batch_size",
metavar="B", metavar="B",
type=int, type=int,
default=10, default=16,
help="Batch size", help="Batch size",
) )
parser.add_argument( parser.add_argument(
@ -195,21 +194,21 @@ def main():
wandb.log( # log training metrics wandb.log( # log training metrics
{ {
"train/epoch": epoch + step / epoch, "train/epoch": epoch + step / len(train_loader),
"train/train_loss": train_loss, "train/train_loss": train_loss,
} }
) )
# Evaluation round # Evaluation round
val_loss = evaluate(net, val_loader, device) val_score = evaluate(net, val_loader, device)
scheduler.step(val_loss) scheduler.step(val_score)
wandb.log( # log validation metrics wandb.log( # log validation metrics
{ {
"val/val_loss": val_loss, "val/val_score": val_score,
} }
) )
print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:3f}") print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
# save weights when epoch end # save weights when epoch end
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True) Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)

View file

@ -2,7 +2,6 @@ import os
import random as rd import random as rd
import albumentations as A import albumentations as A
import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -23,16 +22,16 @@ class RandomPaste(A.DualTransform):
def __init__( def __init__(
self, self,
nb, nb,
scale_limit,
path_paste_img_dir, path_paste_img_dir,
path_paste_mask_dir, path_paste_mask_dir,
scale_range=(0.1, 0.2),
always_apply=True, always_apply=True,
p=1.0, p=1.0,
): ):
super().__init__(always_apply, p) super().__init__(always_apply, p)
self.path_paste_img_dir = path_paste_img_dir self.path_paste_img_dir = path_paste_img_dir
self.path_paste_mask_dir = path_paste_mask_dir self.path_paste_mask_dir = path_paste_mask_dir
self.scale_limit = scale_limit self.scale_range = scale_range
self.nb = nb self.nb = nb
@property @property
@ -40,76 +39,77 @@ class RandomPaste(A.DualTransform):
return ["image"] return ["image"]
def apply(self, img, positions, paste_img, paste_mask, **params): def apply(self, img, positions, paste_img, paste_mask, **params):
img = img.copy() # convert img to Image, needed for `paste` function
img = Image.fromarray(img)
w, h = paste_mask.shape for pos in positions:
mask_b = paste_mask > 0 img.paste(paste_img, pos, paste_mask)
mask_rgb_b = np.stack([mask_b, mask_b, mask_b], axis=2)
for (x, y) in positions: return np.asarray(img.convert("RGB"))
img[x : x + w, y : y + h] = img[x : x + w, y : y + h] * ~mask_rgb_b + paste_img * mask_rgb_b
return img
def apply_to_mask(self, mask, positions, paste_mask, **params): def apply_to_mask(self, mask, positions, paste_mask, **params):
mask = mask.copy() # convert mask to Image, needed for `paste` function
mask = Image.fromarray(mask)
w, h = paste_mask.shape for pos in positions:
mask_b = paste_mask > 0 mask.paste(paste_mask, pos, paste_mask)
for (x, y) in positions: return np.asarray(mask.convert("L"))
mask[x : x + w, y : y + h] = mask[x : x + w, y : y + h] * ~mask_b + mask_b
return mask
def get_params_dependent_on_targets(self, params): def get_params_dependent_on_targets(self, params):
# choose a random image inside the image folder
filename = rd.choice(os.listdir(self.path_paste_img_dir)) filename = rd.choice(os.listdir(self.path_paste_img_dir))
paste_img = np.array( # load the "paste" image
Image.open( paste_img = Image.open(
os.path.join( os.path.join(
self.path_paste_img_dir, self.path_paste_img_dir,
filename, filename,
) )
).convert("RGB"), ).convert("RGBA")
dtype=np.uint8,
)
paste_mask = ( # load its respective mask
np.array( paste_mask = Image.open(
Image.open(
os.path.join( os.path.join(
self.path_paste_mask_dir, self.path_paste_mask_dir,
filename, filename,
) )
).convert("L"), ).convert("LA")
dtype=np.float32,
)
/ 255
)
# load the target image
target_img = params["image"] target_img = params["image"]
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
paste_shape = np.array(paste_img.size, dtype=np.uint)
min_scale = min( # compute the minimum scaling to fit inside target image
target_img.shape[0] / paste_img.shape[0], min_scale = np.min(target_shape / paste_shape)
target_img.shape[1] / paste_img.shape[1],
# randomize the relative scaling
scale = rd.uniform(*self.scale_range)
# rotate the image and its mask
angle = rd.uniform(0, 360)
paste_img = paste_img.rotate(angle, expand=True)
paste_mask = paste_mask.rotate(angle, expand=True)
# scale the "paste" image and its mask
paste_img = paste_img.resize(
tuple((paste_shape * min_scale * scale).astype(np.uint)),
resample=Image.Resampling.LANCZOS,
)
paste_mask = paste_mask.resize(
tuple((paste_shape * min_scale * scale).astype(np.uint)),
resample=Image.Resampling.LANCZOS,
) )
rescale_rotate = A.Compose( # update paste_shape after scaling
[ paste_shape = np.array(paste_img.size, dtype=np.uint)
A.Rotate(limit=360, always_apply=True, border_mode=cv2.BORDER_CONSTANT),
A.RandomScale(scale_limit=(min_scale * self.scale_limit - 1, -0.99), always_apply=True),
],
)
augmentations = rescale_rotate(image=paste_img, mask=paste_mask)
paste_img = augmentations["image"]
paste_mask = augmentations["mask"]
# generate some positions
positions = [] positions = []
for _ in range(rd.randint(1, self.nb)): for _ in range(rd.randint(1, self.nb)):
x = rd.randint(0, target_img.shape[0] - paste_img.shape[0]) x = rd.randint(0, target_shape[0] - paste_shape[0])
y = rd.randint(0, target_img.shape[1] - paste_img.shape[1]) y = rd.randint(0, target_shape[1] - paste_shape[1])
positions.append((x, y)) positions.append((x, y))
params.update( params.update(
@ -123,4 +123,4 @@ class RandomPaste(A.DualTransform):
return params return params
def get_transform_init_args_names(self): def get_transform_init_args_names(self):
return "scale_limit", "path_paste_img_dir", "path_paste_mask_dir" return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"