From 8611d8cd7a372a40df31a940e3567c50013902a1 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 7 Jul 2022 16:31:53 +0200 Subject: [PATCH] feat: better paste augmentation Former-commit-id: 2adef7920e5f317ac3fbe0205862e29d49c2af8f [formerly 41cb0c231b00a1e992847723eb754af1a9e28eee] Former-commit-id: f826c62f4aa3b0c9d2ea7b49f49b5839072ff259 --- comp.ipynb.REMOVED.git-id | 2 +- src/train.py | 6 -- src/utils/paste.py | 157 +++++++++++++++++++++++--------------- 3 files changed, 97 insertions(+), 68 deletions(-) diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id index 6ca72da..21ecac1 100644 --- a/comp.ipynb.REMOVED.git-id +++ b/comp.ipynb.REMOVED.git-id @@ -1 +1 @@ -5ef2ef54312186cd3e3162869c4f237b69de3b1e \ No newline at end of file +0f3136c724eea42fdf1ee15e721ef33604e9a46d \ No newline at end of file diff --git a/src/train.py b/src/train.py index 1f72d51..bb0c2e4 100644 --- a/src/train.py +++ b/src/train.py @@ -14,13 +14,7 @@ CONFIG = { "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres_prod/", - # "FEATURES": [1, 2, 4, 8], - # "FEATURES": [4, 8, 16, 32], "FEATURES": [8, 16, 32, 64], - # "FEATURES": [4, 8, 16, 32, 64], - # "FEATURES": [8, 16, 32, 64, 128], - # "FEATURES": [16, 32, 64, 128], - # "FEATURES": [64, 128, 256, 512], "N_CHANNELS": 3, "N_CLASSES": 1, "AMP": True, diff --git a/src/utils/paste.py b/src/utils/paste.py index 0c4db00..bf662eb 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -3,7 +3,8 @@ from pathlib import Path import albumentations as A import numpy as np -from PIL import Image, ImageEnhance +import torchvision.transforms as T +from PIL import Image class RandomPaste(A.DualTransform): @@ -38,105 +39,139 @@ class RandomPaste(A.DualTransform): def targets_as_params(self): return ["image"] - def apply(self, img, positions, paste_img, paste_mask, **params): + def apply(self, img, augmentations, paste_img, paste_mask, **params): # convert img to Image, needed for `paste` function img = Image.fromarray(img) + # copy paste_img and paste_mask + paste_mask = paste_mask.copy() + paste_img = paste_img.copy() + # paste spheres - for pos in positions: - img.paste(paste_img, pos, paste_mask) + for (x, y, shearx, sheary, shape, angle, brightness, contrast) in augmentations: + paste_img = T.functional.adjust_contrast( + paste_img, + contrast_factor=contrast, + ) + paste_img = T.functional.adjust_brightness( + paste_img, + brightness_factor=brightness, + ) + paste_img = T.functional.affine( + paste_img, + scale=0.95, + angle=angle, + translate=(0, 0), + shear=(shearx, sheary), + interpolation=T.InterpolationMode.BICUBIC, + ) + paste_img = T.functional.resize( + paste_img, + size=shape, + interpolation=T.InterpolationMode.BICUBIC, + ) + + paste_mask = T.functional.affine( + paste_mask, + scale=0.95, + angle=angle, + translate=(0, 0), + shear=(shearx, sheary), + interpolation=T.InterpolationMode.BICUBIC, + ) + paste_mask = T.functional.resize( + paste_mask, + size=shape, + interpolation=T.InterpolationMode.BICUBIC, + ) + + img.paste(paste_img, (x, y), paste_mask) return np.asarray(img.convert("RGB")) - def apply_to_mask(self, mask, positions, paste_mask, **params): + def apply_to_mask(self, mask, augmentations, paste_mask, **params): # convert mask to Image, needed for `paste` function mask = Image.fromarray(mask) - # binarize the mask -> {0, 1} - paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0) + # copy paste_img and paste_mask + paste_mask = paste_mask.copy() - # paste spheres - for pos in positions: - mask.paste(paste_mask, pos, paste_mask_bin) + for (x, y, shearx, sheary, shape, angle, _, _) in augmentations: + paste_mask = T.functional.affine( + paste_mask, + scale=0.95, + angle=angle, + translate=(0, 0), + shear=(shearx, sheary), + interpolation=T.InterpolationMode.BICUBIC, + ) + paste_mask = T.functional.resize( + paste_mask, + size=shape, + interpolation=T.InterpolationMode.BICUBIC, + ) + + # binarize the mask -> {0, 1} + paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0) + + mask.paste(paste_mask, (x, y), paste_mask_bin) return np.asarray(mask.convert("L")) - @staticmethod - def overlap(positions, x1, y1, w, h): - for x2, y2 in positions: - if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h: - return True - return False - def get_params_dependent_on_targets(self, params): # choose a random image and its corresponding mask img_path = rd.choice(self.images) mask_path = img_path.parent.joinpath("MASK.PNG") - # load the "paste" image + # load images (w/ transparency) paste_img = Image.open(img_path).convert("RGBA") - - # load its respective mask paste_mask = Image.open(mask_path).convert("LA") - - # load the target image target_img = params["image"] - # compute shapes, for easier computations + # compute shapes target_shape = np.array(target_img.shape[:2], dtype=np.uint) paste_shape = np.array(paste_img.size, dtype=np.uint) - # change paste_img's contrast randomly - filter = ImageEnhance.Contrast(paste_img) - paste_img = filter.enhance(rd.uniform(0.5, 1.5)) - - # change paste_img's brightness randomly - filter = ImageEnhance.Brightness(paste_img) - paste_img = filter.enhance(rd.uniform(0.5, 1.5)) - - # compute the minimum scaling to fit inside target image + # compute minimum scaling to fit inside target min_scale = np.min(target_shape / paste_shape) - # randomize the relative scaling - scale = rd.uniform(*self.scale_range) - - # rotate the image and its mask - angle = rd.uniform(0, 360) - paste_img = paste_img.rotate(angle, expand=True) - paste_mask = paste_mask.rotate(angle, expand=True) - - # scale the "paste" image and its mask - paste_img = paste_img.resize( - tuple((paste_shape * min_scale * scale).astype(np.uint)), - resample=Image.Resampling.LANCZOS, - ) - paste_mask = paste_mask.resize( - tuple((paste_shape * min_scale * scale).astype(np.uint)), - resample=Image.Resampling.LANCZOS, - ) - - # update paste_shape after scaling - paste_shape = np.array(paste_img.size, dtype=np.uint) - - # generate some positions - positions = [] + # generate augmentations + augmentations = [] NB = rd.randint(1, self.nb) - while len(positions) < NB: - x = rd.randint(0, target_shape[0] - paste_shape[0]) - y = rd.randint(0, target_shape[1] - paste_shape[1]) + while len(augmentations) < NB: # TODO: mettre une condition d'arret ite max + scale = rd.uniform(*self.scale_range) * min_scale + shape = np.array(paste_shape * scale, dtype=np.uint) + + x = rd.randint(0, target_shape[0] - shape[0]) + y = rd.randint(0, target_shape[1] - shape[1]) # check for overlapping - if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]): + if RandomPaste.overlap(augmentations, x, y, shape[0], shape[1]): continue - positions.append((x, y)) + shearx = rd.uniform(-2, 2) + sheary = rd.uniform(-2, 2) + + angle = rd.uniform(0, 360) + + brightness = rd.uniform(0.8, 1.2) + contrast = rd.uniform(0.8, 1.2) + + augmentations.append((x, y, shearx, sheary, tuple(shape), angle, brightness, contrast)) params.update( { - "positions": positions, + "augmentations": augmentations, "paste_img": paste_img, "paste_mask": paste_mask, } ) return params + + @staticmethod + def overlap(positions, x1, y1, w, h): + for x2, y2, _, _, _, _, _, _ in positions: + if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h: + return True + return False