feat: support multi image pasting

This commit is contained in:
Laurent Fainsin 2022-09-14 16:23:19 +02:00
parent c374673786
commit 57853be03e
3 changed files with 150 additions and 127 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,5 +1,9 @@
from __future__ import annotations
import random as rd import random as rd
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List, Tuple
import albumentations as A import albumentations as A
import numpy as np import numpy as np
@ -23,162 +27,201 @@ class RandomPaste(A.DualTransform):
def __init__(
    self,
    nb,
    sphere_image_dir,
    chrome_sphere_image_dir,
    scale_range=(0.05, 0.3),
    always_apply=True,
    p=1.0,
):
    """Index the paste assets and store the pasting configuration.

    Args:
        nb: upper bound of the number of pastes per call; the actual
            count is drawn with ``rd.randint(1, nb)`` when parameters
            are generated.
        sphere_image_dir: directory searched recursively for ``.jpg``
            and ``.png`` sphere images.
        chrome_sphere_image_dir: directory searched recursively for
            ``.jpg`` and ``.png`` chrome sphere images.
        scale_range: ``(low, high)`` bounds of the random factor that is
            multiplied with the largest scale at which the paste still
            fits inside the target image.
        always_apply: forwarded unchanged to the parent constructor.
        p: forwarded unchanged to the parent constructor.
    """
    super().__init__(always_apply, p)

    # recursively index the assets; jpg files are listed before png files
    self.sphere_images = [
        *Path(sphere_image_dir).glob("**/*.jpg"),
        *Path(sphere_image_dir).glob("**/*.png"),
    ]
    self.chrome_sphere_images = [
        *Path(chrome_sphere_image_dir).glob("**/*.jpg"),
        *Path(chrome_sphere_image_dir).glob("**/*.png"),
    ]

    self.scale_range = scale_range
    self.nb = nb

    # placements shared between `get_params_dependent_on_targets` (writer)
    # and `apply` / `apply_to_mask` (readers)
    self.augmentation_datas: List[AugmentationData] = []
@property
def targets_as_params(self):
    """Names of the targets read by ``get_params_dependent_on_targets``.

    Only ``"image"`` is listed: parameter generation looks up
    ``params["image"]`` to size the pastes against the target image.
    """
    return ["image"]
def apply(self, img, **params):
    """Paste every prepared augmentation onto the image.

    Args:
        img: array accepted by ``Image.fromarray``; it is the picture the
            pastes are drawn on.
        **params: extra transform parameters, unused here.

    Returns:
        The composited picture as a numpy array converted to ``"RGB"``.
    """
    # convert img to Image, needed for `paste` function
    canvas = Image.fromarray(img)

    # paste spheres
    for augmentation in self.augmentation_datas:
        # photometric jitter, applied to the pasted picture only
        paste_img_aug = T.functional.adjust_contrast(
            augmentation.paste_img,
            contrast_factor=augmentation.contrast,
        )
        paste_img_aug = T.functional.adjust_brightness(
            paste_img_aug,
            brightness_factor=augmentation.brightness,
        )

        # geometric jitter; the mask below receives the very same
        # affine + resize so that it stays aligned with the picture
        # NOTE(review): scale=0.95 presumably keeps a margin so rotation and
        # shear do not clip the content - confirm
        paste_img_aug = T.functional.affine(
            paste_img_aug,
            scale=0.95,
            translate=(0, 0),
            angle=augmentation.angle,
            shear=augmentation.shear,
            interpolation=T.InterpolationMode.BICUBIC,
        )
        paste_img_aug = T.functional.resize(
            paste_img_aug,
            size=augmentation.shape,
            interpolation=T.InterpolationMode.LANCZOS,
        )

        paste_mask_aug = T.functional.affine(
            augmentation.paste_mask,
            scale=0.95,
            translate=(0, 0),
            angle=augmentation.angle,
            shear=augmentation.shear,
            interpolation=T.InterpolationMode.BICUBIC,
        )
        paste_mask_aug = T.functional.resize(
            paste_mask_aug,
            size=augmentation.shape,
            interpolation=T.InterpolationMode.LANCZOS,
        )

        # the transformed mask is handed to `paste` as its mask argument
        canvas.paste(paste_img_aug, augmentation.position, paste_mask_aug)

    return np.array(canvas.convert("RGB"))
def apply_to_mask(self, mask, **params):
    """Draw every prepared augmentation into the segmentation mask.

    Args:
        mask: array accepted by ``Image.fromarray``.
        **params: extra transform parameters, unused here.

    Returns:
        The updated mask as a numpy array converted to mode ``"L"``.
    """
    # convert mask to Image, needed for `paste` function
    canvas = Image.fromarray(mask)

    for augmentation in self.augmentation_datas:
        # same affine + resize as in `apply`, so the mask lines up with
        # the picture that was pasted
        paste_mask_aug = T.functional.affine(
            augmentation.paste_mask,
            scale=0.95,
            translate=(0, 0),
            angle=augmentation.angle,
            shear=augmentation.shear,
            interpolation=T.InterpolationMode.BICUBIC,
        )
        paste_mask_aug = T.functional.resize(
            paste_mask_aug,
            size=augmentation.shape,
            interpolation=T.InterpolationMode.LANCZOS,
        )

        # threshold the mask -> {0, augmentation.value}: pixels above 10
        # take this augmentation's value, every other pixel becomes 0
        # (the lambda is consumed immediately, so capturing the loop
        # variable is safe)
        paste_mask_aug_bin = paste_mask_aug.point(lambda p: augmentation.value if p > 10 else 0)

        # the thresholded picture is used as the `mask` argument of paste
        canvas.paste(paste_mask_aug, augmentation.position, paste_mask_aug_bin)

    return np.array(canvas.convert("L"))
def get_params_dependent_on_targets(self, params):
    """Draw a fresh set of non-overlapping paste augmentations.

    The placements are stored in ``self.augmentation_datas``, where
    ``apply`` and ``apply_to_mask`` read them. The list is rebuilt on
    every call so that placements computed for a previous target image
    are never reused for the current one.

    Args:
        params: dictionary holding the target image under ``"image"``.

    Returns:
        ``params``, unchanged.
    """
    # target image the pastes are sized and positioned against
    target_img = params["image"]
    target_shape = np.array(target_img.shape[:2], dtype=np.uint)
    # plain ints: unsigned numpy scalars would wrap around instead of going
    # negative when a paste is larger than the target
    target_height, target_width = (int(v) for v in target_shape)

    # start from an empty list, shared with the dataclass overlap check
    augmentation_datas = []
    self.augmentation_datas = augmentation_datas

    # generate augmentations
    attempts = 0
    nb_pastes = rd.randint(1, self.nb)
    while len(augmentation_datas) < nb_pastes:
        # bounded retries: rejected placements must not loop forever
        if attempts > 100:
            break
        attempts += 1

        # choose a random sphere image and its corresponding mask;
        # spheres are labelled 1, 2, ... and chrome spheres 255, 254, ...
        if rd.random() > 0.5:
            img_path = rd.choice(self.sphere_images)
            value = len(augmentation_datas) + 1
        else:
            img_path = rd.choice(self.chrome_sphere_images)
            value = 255 - len(augmentation_datas)
        mask_path = img_path.parent.joinpath("MASK.PNG")

        # load paste assets; `convert` returns an independent image, so the
        # file handles can be released right away
        with Image.open(img_path) as img_file:
            paste_img = img_file.convert("RGBA")
        with Image.open(mask_path) as mask_file:
            paste_mask = mask_file.convert("LA")
        paste_shape = np.array(paste_img.size, dtype=np.uint)

        # compute minimum scaling to fit inside target
        min_scale = np.min(target_shape / paste_shape)

        # randomly scale image inside target
        scale = rd.uniform(*self.scale_range) * min_scale
        shape = tuple(int(v) for v in np.array(paste_shape * scale, dtype=np.uint))

        # NOTE(review): `paste_img.size` is (width, height) while the bounds
        # below pair shape[1] with the target width; this is only consistent
        # for square paste images - confirm the assets are square
        try:
            augmentation_datas.append(
                AugmentationData(
                    position=(
                        rd.randint(0, target_width - shape[1]),
                        rd.randint(0, target_height - shape[0]),
                    ),
                    shear=(
                        rd.uniform(-2, 2),
                        rd.uniform(-2, 2),
                    ),
                    shape=shape,
                    angle=rd.uniform(0, 360),
                    brightness=rd.uniform(0.8, 1.2),
                    contrast=rd.uniform(0.8, 1.2),
                    paste_img=paste_img,
                    paste_mask=paste_mask,
                    value=value,
                    other_augmentations=augmentation_datas,
                )
            )
        except ValueError:
            # raised for an overlapping placement, or by `randint` when the
            # paste does not fit inside the target: draw again
            continue

    return params
@dataclass
class AugmentationData:
    """Store data for pasting augmentation.

    Instances validate themselves on creation: ``__post_init__`` raises
    ``ValueError`` when the new placement overlaps one of
    ``other_augmentations``.
    """

    # coordinates handed to `Image.paste` as the paste position
    position: Tuple[int, int]

    # size handed to `T.functional.resize`
    shape: Tuple[int, int]

    # rotation handed to `T.functional.affine`
    angle: float

    # factor handed to `T.functional.adjust_brightness`
    brightness: float

    # factor handed to `T.functional.adjust_contrast`
    contrast: float

    # shear pair handed to `T.functional.affine`
    shear: Tuple[float, float]

    # picture to paste and the mask selecting its visible pixels
    paste_img: "Image.Image"
    paste_mask: "Image.Image"

    # value written for this paste when the mask is thresholded
    value: int

    # placements already accepted, checked against on creation
    other_augmentations: List["AugmentationData"]

    def __post_init__(self) -> None:
        # check for overlapping
        if overlap(self.other_augmentations, self):
            raise ValueError("augmentation overlaps a previously placed one")
def overlap(augmentations: "List[AugmentationData]", augmentation: "AugmentationData") -> bool:
    """Tell whether `augmentation` intersects any entry of `augmentations`.

    Each entry is treated as an axis-aligned box built from its
    ``position`` and ``shape``. The comparisons are inclusive, so boxes
    that merely touch along an edge count as overlapping.

    Args:
        augmentations: placements to test against.
        augmentation: candidate placement.

    Returns:
        ``True`` as soon as one intersecting placement is found, ``False``
        otherwise (including when `augmentations` is empty).
    """
    x1, y1 = augmentation.position
    w1, h1 = augmentation.shape

    for other_augmentation in augmentations:
        x2, y2 = other_augmentation.position
        w2, h2 = other_augmentation.shape

        # the boxes intersect only if their spans intersect on both axes
        if x1 + w1 >= x2 and x1 <= x2 + w2 and y1 + h1 >= y2 and y1 <= y2 + h2:
            return True

    return False