From 92ac3a2ab889a298fd8a898d662f5d2a8a4f8a43 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 28 Jun 2022 16:36:50 +0200 Subject: [PATCH] feat: better pasting function Former-commit-id: 43fedd3f6bafb51fe604e347f59b70cd5b0cc218 [formerly 51bb06c3b98df613710b329d3ade1febaf2b0b23] Former-commit-id: 46b89acd2b860d272ce8a13cf2c8c955d7545c46 --- comp.ipynb.REMOVED.git-id | 1 + src/evaluate.py | 10 +++- src/train.py | 13 +++-- src/utils/paste.py | 110 +++++++++++++++++++------------------- 4 files changed, 70 insertions(+), 64 deletions(-) create mode 100644 comp.ipynb.REMOVED.git-id diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id new file mode 100644 index 0000000..5ba568f --- /dev/null +++ b/comp.ipynb.REMOVED.git-id @@ -0,0 +1 @@ +a27ff5fec9fd71b7846a70ebc473984e859912b8 \ No newline at end of file diff --git a/src/evaluate.py b/src/evaluate.py index 3d56d51..af3d1dc 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -1,8 +1,8 @@ import torch -import torch.nn.functional as F from tqdm import tqdm -from src.utils.dice import dice_coeff, multiclass_dice_coeff +import wandb +from src.utils.dice import dice_coeff def evaluate(net, dataloader, device): @@ -27,6 +27,12 @@ def evaluate(net, dataloader, device): pbar.update(images.shape[0]) + # save some images to wandb + table = wandb.Table(columns=["image", "mask", "prediction"]) + for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")): + table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) + wandb.log({"predictions_table": table}, commit=False) + net.train() # Fixes a potential division by zero error diff --git a/src/train.py b/src/train.py index 768de17..fcd901e 100644 --- a/src/train.py +++ b/src/train.py @@ -14,7 +14,6 @@ from tqdm import tqdm import wandb from evaluate import evaluate from src.utils.dataset import SphereDataset -from src.utils.dice import dice_loss from unet import UNet from utils.paste import RandomPaste @@ -43,7 +42,7 @@ def get_args(): dest="batch_size", metavar="B", type=int, - default=10, + default=16, help="Batch size", ) parser.add_argument( @@ -195,21 +194,21 @@ def main(): wandb.log( # log training metrics { - "train/epoch": epoch + step / epoch, + "train/epoch": epoch + step / len(train_loader), "train/train_loss": train_loss, } ) # Evaluation round - val_loss = evaluate(net, val_loader, device) - scheduler.step(val_loss) + val_score = evaluate(net, val_loader, device) + scheduler.step(val_score) wandb.log( # log validation metrics { - "val/val_loss": val_loss, + "val/val_score": val_score, } ) - print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:3f}") + print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}") # save weights when epoch end Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True) diff --git a/src/utils/paste.py b/src/utils/paste.py index 0f87c7e..5ad8961 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -2,7 +2,6 @@ import os import random as rd import albumentations as A -import cv2 import numpy as np from PIL import Image @@ -23,16 +22,16 @@ class RandomPaste(A.DualTransform): def __init__( self, nb, - scale_limit, path_paste_img_dir, path_paste_mask_dir, + scale_range=(0.1, 0.2), always_apply=True, p=1.0, ): super().__init__(always_apply, p) self.path_paste_img_dir = path_paste_img_dir self.path_paste_mask_dir = path_paste_mask_dir - self.scale_limit = scale_limit + self.scale_range = scale_range self.nb = nb @property @@ -40,76 +39,77 @@ class RandomPaste(A.DualTransform): return ["image"] def apply(self, img, positions, paste_img, paste_mask, **params): - img = img.copy() + # convert img to Image, needed for `paste` function + img = Image.fromarray(img) - w, h = paste_mask.shape - mask_b = paste_mask > 0 - mask_rgb_b = np.stack([mask_b, mask_b, mask_b], axis=2) + for pos in positions: + img.paste(paste_img, pos, paste_mask) - for (x, y) in positions: - img[x : x + w, y : y + h] = img[x : x + w, y : y + h] * ~mask_rgb_b + paste_img * mask_rgb_b - - return img + return np.asarray(img.convert("RGB")) def apply_to_mask(self, mask, positions, paste_mask, **params): - mask = mask.copy() + # convert mask to Image, needed for `paste` function + mask = Image.fromarray(mask) - w, h = paste_mask.shape - mask_b = paste_mask > 0 + for pos in positions: + mask.paste(paste_mask, pos, paste_mask) - for (x, y) in positions: - mask[x : x + w, y : y + h] = mask[x : x + w, y : y + h] * ~mask_b + mask_b - - return mask + return np.asarray(mask.convert("L")) def get_params_dependent_on_targets(self, params): + # choose a random image inside the image folder filename = rd.choice(os.listdir(self.path_paste_img_dir)) - paste_img = np.array( - Image.open( - os.path.join( - self.path_paste_img_dir, - filename, - ) - ).convert("RGB"), - dtype=np.uint8, - ) - - paste_mask = ( - np.array( - Image.open( - os.path.join( - self.path_paste_mask_dir, - filename, - ) - ).convert("L"), - dtype=np.float32, + # load the "paste" image + paste_img = Image.open( + os.path.join( + self.path_paste_img_dir, + filename, ) - / 255 - ) + ).convert("RGBA") + # load its respective mask + paste_mask = Image.open( + os.path.join( + self.path_paste_mask_dir, + filename, + ) + ).convert("LA") + + # load the target image target_img = params["image"] + target_shape = np.array(target_img.shape[:2], dtype=np.uint) + paste_shape = np.array(paste_img.size, dtype=np.uint) - min_scale = min( - target_img.shape[0] / paste_img.shape[0], - target_img.shape[1] / paste_img.shape[1], + # compute the minimum scaling to fit inside target image + min_scale = np.min(target_shape / paste_shape) + + # randomize the relative scaling + scale = rd.uniform(*self.scale_range) + + # rotate the image and its mask + angle = rd.uniform(0, 360) + paste_img = paste_img.rotate(angle, expand=True) + paste_mask = paste_mask.rotate(angle, expand=True) + + # scale the "paste" image and its mask + paste_img = paste_img.resize( + tuple((paste_shape * min_scale * scale).astype(np.uint)), + resample=Image.Resampling.LANCZOS, + ) + paste_mask = paste_mask.resize( + tuple((paste_shape * min_scale * scale).astype(np.uint)), + resample=Image.Resampling.LANCZOS, ) - rescale_rotate = A.Compose( - [ - A.Rotate(limit=360, always_apply=True, border_mode=cv2.BORDER_CONSTANT), - A.RandomScale(scale_limit=(min_scale * self.scale_limit - 1, -0.99), always_apply=True), - ], - ) - - augmentations = rescale_rotate(image=paste_img, mask=paste_mask) - paste_img = augmentations["image"] - paste_mask = augmentations["mask"] + # update paste_shape after scaling + paste_shape = np.array(paste_img.size, dtype=np.uint) + # generate some positions positions = [] for _ in range(rd.randint(1, self.nb)): - x = rd.randint(0, target_img.shape[0] - paste_img.shape[0]) - y = rd.randint(0, target_img.shape[1] - paste_img.shape[1]) + x = rd.randint(0, target_shape[0] - paste_shape[0]) + y = rd.randint(0, target_shape[1] - paste_shape[1]) positions.append((x, y)) params.update( @@ -123,4 +123,4 @@ class RandomPaste(A.DualTransform): return params def get_transform_init_args_names(self): - return "scale_limit", "path_paste_img_dir", "path_paste_mask_dir" + return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"