feat: support multi image pasting
This commit is contained in:
parent
c374673786
commit
57853be03e
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,5 +1,9 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import random as rd
|
import random as rd
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -23,162 +27,201 @@ class RandomPaste(A.DualTransform):
|
||||||
def __init__(
    self,
    nb,
    sphere_image_dir,
    chrome_sphere_image_dir,
    scale_range=(0.05, 0.3),
    always_apply=True,
    p=1.0,
):
    """Index the paste-able images and store the sampling settings.

    Args:
        nb: Upper bound (inclusive) on the number of pastes drawn per
            target image.
        sphere_image_dir: Directory searched recursively for ``.jpg`` and
            ``.png`` sphere images.
        chrome_sphere_image_dir: Directory searched recursively for
            ``.jpg`` and ``.png`` chrome sphere images.
        scale_range: ``(low, high)`` bounds of the random factor multiplied
            with the largest scale that still fits inside the target.
        always_apply: Forwarded unchanged to the parent transform.
        p: Forwarded unchanged to the parent transform.
    """
    super().__init__(always_apply, p)

    # jpg files first, then png files, for each of the two image pools
    sphere_root = Path(sphere_image_dir)
    self.sphere_images = [
        *sphere_root.glob("**/*.jpg"),
        *sphere_root.glob("**/*.png"),
    ]

    chrome_root = Path(chrome_sphere_image_dir)
    self.chrome_sphere_images = [
        *chrome_root.glob("**/*.jpg"),
        *chrome_root.glob("**/*.png"),
    ]

    self.scale_range = scale_range
    self.nb = nb

    # pastes sampled for the current target; read by `apply` and `apply_to_mask`
    self.augmentation_datas: List[AugmentationData] = []
|
||||||
|
|
||||||
@property
def targets_as_params(self):
    """Targets whose data is needed to sample the parameters.

    Only the image is requested: `get_params_dependent_on_targets` reads
    ``params["image"]`` to size the pastes.
    """
    return ["image"]
|
||||||
|
|
||||||
def apply(self, img, **params):
    """Paste every sampled augmentation onto the image.

    Args:
        img: Target image, as an array accepted by ``Image.fromarray``.
        **params: Ignored; the pastes are read from
            ``self.augmentation_datas``.

    Returns:
        The composited image as an RGB ``numpy`` array.
    """
    # convert img to Image, needed for `paste` function
    img = Image.fromarray(img)

    # paste spheres
    for augmentation in self.augmentation_datas:
        # The pasted pixels and their mask must go through exactly the same
        # geometric transform, otherwise the mask no longer lines up with
        # the sphere it cuts out; share the keyword sets between both.
        geometry = {
            "scale": 0.95,
            "translate": (0, 0),
            "angle": augmentation.angle,
            "shear": augmentation.shear,
            "interpolation": T.InterpolationMode.BICUBIC,
        }
        sizing = {
            "size": augmentation.shape,
            "interpolation": T.InterpolationMode.LANCZOS,
        }

        # photometric changes only apply to the pasted pixels, not the mask
        paste_img_aug = T.functional.adjust_contrast(
            augmentation.paste_img,
            contrast_factor=augmentation.contrast,
        )
        paste_img_aug = T.functional.adjust_brightness(
            paste_img_aug,
            brightness_factor=augmentation.brightness,
        )
        paste_img_aug = T.functional.affine(paste_img_aug, **geometry)
        paste_img_aug = T.functional.resize(paste_img_aug, **sizing)

        paste_mask_aug = T.functional.affine(augmentation.paste_mask, **geometry)
        paste_mask_aug = T.functional.resize(paste_mask_aug, **sizing)

        img.paste(paste_img_aug, augmentation.position, paste_mask_aug)

    return np.array(img.convert("RGB"))
|
||||||
|
|
||||||
def apply_to_mask(self, mask, **params):
    """Paste the label of every sampled augmentation onto the mask.

    Args:
        mask: Target mask, as an array accepted by ``Image.fromarray``.
        **params: Ignored; the pastes are read from
            ``self.augmentation_datas``.

    Returns:
        The updated mask as a single-channel (mode ``"L"``) ``numpy`` array.
    """
    # convert mask to Image, needed for `paste` function
    mask = Image.fromarray(mask)

    for augmentation in self.augmentation_datas:
        # same geometric transform as the one applied to the pasted image
        paste_mask_aug = T.functional.affine(
            augmentation.paste_mask,
            scale=0.95,
            translate=(0, 0),
            angle=augmentation.angle,
            shear=augmentation.shear,
            interpolation=T.InterpolationMode.BICUBIC,
        )
        paste_mask_aug = T.functional.resize(
            paste_mask_aug,
            size=augmentation.shape,
            interpolation=T.InterpolationMode.LANCZOS,
        )

        # threshold the mask -> {0, value}: pixels above 10 take the label
        # assigned to this augmentation, everything else becomes 0
        label = augmentation.value
        paste_mask_aug_bin = paste_mask_aug.point(lambda p: label if p > 10 else 0)

        # NOTE(review): the thresholded image is used as the blending mask and
        # the un-thresholded one as the pasted content; presumably this relies
        # on the source mask being pure white so the blend yields `label`.
        # Confirm against the MASK.PNG assets.
        mask.paste(paste_mask_aug, augmentation.position, paste_mask_aug_bin)

    return np.array(mask.convert("L"))
|
||||||
|
|
||||||
def get_params_dependent_on_targets(self, params):
    """Sample the pastes that `apply` and `apply_to_mask` will render.

    Draws between 1 and ``self.nb`` non-overlapping augmentations for the
    target image and stores them in ``self.augmentation_datas``. Sampling
    gives up after 101 attempts, so fewer pastes than requested may be kept.

    Args:
        params: Mapping holding the target under the ``"image"`` key.

    Returns:
        ``params`` unchanged; the sampled data is kept on the instance.
    """
    # Start from a clean slate for every target. Without this the list filled
    # by a previous call satisfies the loop condition below, so later samples
    # would reuse stale pastes whose positions were computed for another image.
    # A new list is bound (rather than clearing in place) so that objects from
    # a previous call keep the neighbour list they were validated against.
    self.augmentation_datas = []

    # load target image (w/ transparency)
    target_img = params["image"]
    target_shape = np.array(target_img.shape[:2], dtype=np.uint)

    # generate augmentations
    attempts = 0
    nb_pastes = rd.randint(1, self.nb)
    while len(self.augmentation_datas) < nb_pastes and attempts <= 100:
        attempts += 1

        # choose a random sphere image and its corresponding mask;
        # plain spheres are labelled upwards from 1, chrome spheres downwards
        # from 255, based on how many pastes have been accepted so far
        if rd.random() > 0.5:
            img_path = rd.choice(self.sphere_images)
            value = len(self.augmentation_datas) + 1
        else:
            img_path = rd.choice(self.chrome_sphere_images)
            value = 255 - len(self.augmentation_datas)
        mask_path = img_path.parent.joinpath("MASK.PNG")

        # load paste assets
        paste_img = Image.open(img_path).convert("RGBA")
        paste_shape = np.array(paste_img.size, dtype=np.uint)
        paste_mask = Image.open(mask_path).convert("LA")

        # compute minimum scaling to fit inside target
        # NOTE(review): `target_shape` comes from `ndarray.shape[:2]` while
        # `paste_shape` comes from `Image.size`; if these use different axis
        # orders the ratio mixes axes for non-square images -- confirm.
        min_scale = np.min(target_shape / paste_shape)

        # randomly scale image inside target
        scale = rd.uniform(*self.scale_range) * min_scale
        shape = np.array(paste_shape * scale, dtype=np.uint)

        try:
            # construction raises ValueError when the candidate overlaps an
            # already accepted paste; the candidate is then simply dropped
            self.augmentation_datas.append(
                AugmentationData(
                    position=(
                        rd.randint(0, target_shape[1] - shape[1]),
                        rd.randint(0, target_shape[0] - shape[0]),
                    ),
                    shear=(
                        rd.uniform(-2, 2),
                        rd.uniform(-2, 2),
                    ),
                    shape=tuple(shape),
                    angle=rd.uniform(0, 360),
                    brightness=rd.uniform(0.8, 1.2),
                    contrast=rd.uniform(0.8, 1.2),
                    paste_img=paste_img,
                    paste_mask=paste_mask,
                    value=value,
                    other_augmentations=self.augmentation_datas,
                )
            )
        except ValueError:
            continue

    return params
|
||||||
|
|
||||||
@dataclass
class AugmentationData:
    """Store data for pasting augmentation.

    Raises:
        ValueError: On construction, when the new augmentation overlaps one
            of ``other_augmentations``.
    """

    # top-left corner handed to `Image.paste`
    position: Tuple[int, int]

    # size handed to the resize step, and rotation angle of the affine step
    shape: Tuple[int, int]
    angle: float

    # photometric factors applied to the pasted image only
    brightness: float
    contrast: float

    # shear handed to the affine step
    shear: Tuple[float, float]

    # source image, its mask, and the label value written into the target mask
    paste_img: Image.Image
    paste_mask: Image.Image
    value: int

    # augmentations already accepted for the same target, checked for overlap
    other_augmentations: List[AugmentationData]

    def __post_init__(self) -> None:
        # check for overlapping
        if overlap(self.other_augmentations, self):
            raise ValueError("augmentation overlaps an already placed augmentation")
|
||||||
|
|
||||||
|
|
||||||
|
def overlap(augmentations: List[AugmentationData], augmentation: AugmentationData) -> bool:
    """Tell whether `augmentation` intersects any of `augmentations`.

    Each augmentation is treated as the axis-aligned box spanned by its
    ``position`` and ``shape``. Boxes that merely touch along an edge or a
    corner count as overlapping.

    Args:
        augmentations: Augmentations to test against.
        augmentation: Candidate augmentation.

    Returns:
        True if the candidate's box intersects at least one other box,
        False otherwise (including when `augmentations` is empty).
    """
    x1, y1 = augmentation.position
    w1, h1 = augmentation.shape

    for other_augmentation in augmentations:
        # a box always intersects itself; skip the candidate if it is already
        # stored in the list so re-validating a placed augmentation works
        if other_augmentation is augmentation:
            continue

        x2, y2 = other_augmentation.position
        w2, h2 = other_augmentation.shape

        if x1 + w1 >= x2 and x1 <= x2 + w2 and y1 + h1 >= y2 and y1 <= y2 + h2:
            return True

    return False
|
||||||
|
|
Loading…
Reference in a new issue