feat: support multi image pasting

This commit is contained in:
Laurent Fainsin 2022-09-14 16:23:19 +02:00
parent c374673786
commit 57853be03e
3 changed files with 150 additions and 127 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,5 +1,9 @@
from __future__ import annotations
import random as rd
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple
import albumentations as A
import numpy as np
@ -23,162 +27,201 @@ class RandomPaste(A.DualTransform):
def __init__(
self,
nb,
image_dir,
scale_range=(0.02, 0.3),
sphere_image_dir,
chrome_sphere_image_dir,
scale_range=(0.05, 0.3),
always_apply=True,
p=1.0,
):
super().__init__(always_apply, p)
self.images = []
self.images.extend(list(Path(image_dir).glob("**/*.jpg")))
self.images.extend(list(Path(image_dir).glob("**/*.png")))
self.sphere_images = []
self.sphere_images.extend(list(Path(sphere_image_dir).glob("**/*.jpg")))
self.sphere_images.extend(list(Path(sphere_image_dir).glob("**/*.png")))
self.chrome_sphere_images = []
self.chrome_sphere_images.extend(list(Path(chrome_sphere_image_dir).glob("**/*.jpg")))
self.chrome_sphere_images.extend(list(Path(chrome_sphere_image_dir).glob("**/*.png")))
self.scale_range = scale_range
self.nb = nb
self.augmentation_datas: List[AugmentationData] = []
@property
def targets_as_params(self):
    """Return the names of the targets this transform needs as parameters.

    Only ``"image"`` is listed; ``get_params_dependent_on_targets`` reads
    ``params["image"]`` to obtain the shape of the target image.
    """
    return ["image"]
def apply(self, img, augmentations, paste_img, paste_mask, **params):
def apply(self, img, **params):
# convert img to Image, needed for `paste` function
img = Image.fromarray(img)
# copy paste_img and paste_mask
paste_mask = paste_mask.copy()
paste_img = paste_img.copy()
# paste spheres
for (x, y, shearx, sheary, shape, angle, brightness, contrast) in augmentations:
for augmentation in self.augmentation_datas:
paste_img_aug = T.functional.adjust_contrast(
paste_img,
contrast_factor=contrast,
augmentation.paste_img,
contrast_factor=augmentation.contrast,
)
paste_img_aug = T.functional.adjust_brightness(
paste_img_aug,
brightness_factor=brightness,
brightness_factor=augmentation.brightness,
)
paste_img_aug = T.functional.affine(
paste_img_aug,
scale=0.95,
angle=angle,
translate=(0, 0),
shear=(shearx, sheary),
angle=augmentation.angle,
shear=augmentation.shear,
interpolation=T.InterpolationMode.BICUBIC,
)
paste_img_aug = T.functional.resize(
paste_img_aug,
size=shape,
size=augmentation.shape,
interpolation=T.InterpolationMode.LANCZOS,
)
paste_mask_aug = T.functional.affine(
paste_mask,
augmentation.paste_mask,
scale=0.95,
angle=angle,
translate=(0, 0),
shear=(shearx, sheary),
angle=augmentation.angle,
shear=augmentation.shear,
interpolation=T.InterpolationMode.BICUBIC,
)
paste_mask_aug = T.functional.resize(
paste_mask_aug,
size=shape,
size=augmentation.shape,
interpolation=T.InterpolationMode.LANCZOS,
)
img.paste(paste_img_aug, (x, y), paste_mask_aug)
img.paste(paste_img_aug, augmentation.position, paste_mask_aug)
return np.array(img.convert("RGB"))
def apply_to_mask(self, mask, augmentations, paste_mask, **params):
def apply_to_mask(self, mask, **params):
# convert mask to Image, needed for `paste` function
mask = Image.fromarray(mask)
# copy paste_img and paste_mask
paste_mask = paste_mask.copy()
for i, (x, y, shearx, sheary, shape, angle, _, _) in enumerate(augmentations):
for augmentation in self.augmentation_datas:
paste_mask_aug = T.functional.affine(
paste_mask,
augmentation.paste_mask,
scale=0.95,
angle=angle,
translate=(0, 0),
shear=(shearx, sheary),
angle=augmentation.angle,
shear=augmentation.shear,
interpolation=T.InterpolationMode.BICUBIC,
)
paste_mask_aug = T.functional.resize(
paste_mask_aug,
size=shape,
size=augmentation.shape,
interpolation=T.InterpolationMode.LANCZOS,
)
# binarize the mask -> {0, 1}
paste_mask_aug_bin = paste_mask_aug.point(lambda p: i + 1 if p > 10 else 0)
paste_mask_aug_bin = paste_mask_aug.point(lambda p: augmentation.value if p > 10 else 0)
mask.paste(paste_mask_aug, (x, y), paste_mask_aug_bin)
mask.paste(paste_mask_aug, augmentation.position, paste_mask_aug_bin)
return np.array(mask.convert("L"))
def get_params_dependent_on_targets(self, params):
# choose a random image and its corresponding mask
img_path = rd.choice(self.images)
mask_path = img_path.parent.joinpath("MASK.PNG")
# load images (w/ transparency)
paste_img = Image.open(img_path).convert("RGBA")
paste_mask = Image.open(mask_path).convert("LA")
# load target image (w/ transparency)
target_img = params["image"]
# compute shapes
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
paste_shape = np.array(paste_img.size, dtype=np.uint)
# compute minimum scaling to fit inside target
min_scale = np.min(target_shape / paste_shape)
# generate augmentations
augmentations = []
NB = rd.randint(1, self.nb)
ite = 0
while len(augmentations) < NB:
NB = rd.randint(1, self.nb)
while len(self.augmentation_datas) < NB:
if ite > 100:
break
else:
ite += 1
# choose a random sphere image and its corresponding mask
if rd.random() > 0.5:
img_path = rd.choice(self.sphere_images)
value = len(self.augmentation_datas) + 1
else:
img_path = rd.choice(self.chrome_sphere_images)
value = 255 - len(self.augmentation_datas)
mask_path = img_path.parent.joinpath("MASK.PNG")
# load paste assets
paste_img = Image.open(img_path).convert("RGBA")
paste_shape = np.array(paste_img.size, dtype=np.uint)
paste_mask = Image.open(mask_path).convert("LA")
# compute minimum scaling to fit inside target
min_scale = np.min(target_shape / paste_shape)
# randomly scale image inside target
scale = rd.uniform(*self.scale_range) * min_scale
shape = np.array(paste_shape * scale, dtype=np.uint)
x = rd.randint(0, target_shape[1] - shape[1])
y = rd.randint(0, target_shape[0] - shape[0])
# check for overlapping
if RandomPaste.overlap(augmentations, x, y, shape[1], shape[0]):
try:
self.augmentation_datas.append(
AugmentationData(
position=(
rd.randint(0, target_shape[1] - shape[1]),
rd.randint(0, target_shape[0] - shape[0]),
),
shear=(
rd.uniform(-2, 2),
rd.uniform(-2, 2),
),
shape=tuple(shape),
angle=rd.uniform(0, 360),
brightness=rd.uniform(0.8, 1.2),
contrast=rd.uniform(0.8, 1.2),
paste_img=paste_img,
paste_mask=paste_mask,
value=value,
other_augmentations=self.augmentation_datas,
)
)
except ValueError:
continue
shearx = rd.uniform(-2, 2)
sheary = rd.uniform(-2, 2)
angle = rd.uniform(0, 360)
brightness = rd.uniform(0.8, 1.2)
contrast = rd.uniform(0.8, 1.2)
augmentations.append((x, y, shearx, sheary, tuple(shape), angle, brightness, contrast))
ite += 1
params.update(
{
"augmentations": augmentations,
"paste_img": paste_img,
"paste_mask": paste_mask,
}
)
return params
@staticmethod
def overlap(positions, x1, y1, w1, h1):
for x2, y2, _, _, (w2, h2), _, _, _ in positions:
if x1 + w1 >= x2 and x1 <= x2 + w2 and y1 + h1 >= y2 and y1 <= y2 + h2:
return True
return False
@dataclass
class AugmentationData:
    """Store data for pasting augmentation.

    One instance describes a single image to paste: where it goes, how it is
    transformed beforehand, and which value marks it in the mask. Creating an
    instance validates it against ``other_augmentations``.

    Raises:
        ValueError: if the new instance overlaps one of ``other_augmentations``.
    """

    # (x, y) location handed to ``Image.paste`` for both image and mask
    position: Tuple[int, int]
    # target size handed to ``T.functional.resize``; read as (w, h) by ``overlap``
    shape: Tuple[int, int]
    # rotation angle handed to ``T.functional.affine``
    angle: float
    # factor handed to ``T.functional.adjust_brightness``
    brightness: float
    # factor handed to ``T.functional.adjust_contrast``
    contrast: float
    # shear pair handed to ``T.functional.affine``
    shear: Tuple[float, float]
    # image to paste and its associated mask image
    paste_img: Image.Image
    paste_mask: Image.Image
    # value assigned to mask pixels above the binarization threshold
    value: int
    # augmentations already accepted, used only for the overlap check below
    other_augmentations: List[AugmentationData]

    def __post_init__(self) -> None:
        # Reject the instance at construction time so that callers can simply
        # catch ValueError and draw a new random placement.
        if overlap(self.other_augmentations, self):
            raise ValueError(
                f"augmentation at {self.position} with shape {self.shape} "
                "overlaps an existing augmentation"
            )
def overlap(augmentations: List[AugmentationData], augmentation: AugmentationData) -> bool:
    """Tell whether ``augmentation`` intersects any item of ``augmentations``.

    Only the ``position`` (x, y) and ``shape`` (w, h) attributes are read.
    The comparisons are inclusive, so rectangles that merely touch are
    reported as overlapping.
    """
    left, top = augmentation.position
    width, height = augmentation.shape

    def intersects(other: AugmentationData) -> bool:
        # axis-aligned bounding-box test against one existing augmentation
        other_left, other_top = other.position
        other_width, other_height = other.shape
        return (
            left + width >= other_left
            and left <= other_left + other_width
            and top + height >= other_top
            and top <= other_top + other_height
        )

    return any(intersects(other) for other in augmentations)