From 0dfa23fa53967490370057ac31b5969b250ed7b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Deltheil?=
Date: Thu, 5 Oct 2023 11:13:15 +0200
Subject: [PATCH] fluxion: add gaussian_blur to utils

---
 src/refiners/fluxion/utils.py | 52 ++++++++++++++++++++++++++++++++++-
 1 file changed, 51 insertions(+), 1 deletion(-)

diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py
index fb1ce6f..8c69de2 100644
--- a/src/refiners/fluxion/utils.py
+++ b/src/refiners/fluxion/utils.py
@@ -6,7 +6,7 @@ from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
 from torch import norm as _norm, manual_seed as _manual_seed  # type: ignore
 import torch
-from torch.nn.functional import pad as _pad, interpolate as _interpolate  # type: ignore
+from torch.nn.functional import pad as _pad, interpolate as _interpolate, conv2d  # type: ignore
 from torch import Tensor, device as Device, dtype as DType
 from jaxtyping import Float
 
@@ -51,6 +51,56 @@ def normalize(
     return (tensor - pixel_mean) / pixel_std
 
 
+# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
+def gaussian_blur(
+    tensor: Float[Tensor, "*batch channels height width"],
+    kernel_size: int | tuple[int, int],
+    sigma: float | tuple[float, float] | None = None,
+) -> Float[Tensor, "*batch channels height width"]:
+    assert torch.is_floating_point(tensor)
+
+    def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
+        ksize_half = (kernel_size - 1) * 0.5
+        x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+        pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+        kernel1d = pdf / pdf.sum()
+        return kernel1d
+
+    def get_gaussian_kernel2d(
+        kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
+    ) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
+        kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
+        kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
+        kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
+        return kernel2d
+
+    def default_sigma(kernel_size: int) -> float:
+        return kernel_size * 0.15 + 0.35
+
+    if isinstance(kernel_size, int):
+        kx, ky = kernel_size, kernel_size
+    else:
+        kx, ky = kernel_size
+
+    if sigma is None:
+        sx, sy = default_sigma(kx), default_sigma(ky)
+    elif isinstance(sigma, float):
+        sx, sy = sigma, sigma
+    else:
+        assert isinstance(sigma, tuple)  # TODO: remove with pyright +1.1.330
+        sx, sy = sigma
+
+    channels = tensor.shape[-3]
+    kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=torch.float32, device=tensor.device)
+    kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])
+
+    # pad = (left, right, top, bottom)
+    tensor = pad(tensor, pad=(kx // 2, kx // 2, ky // 2, ky // 2), mode="reflect")
+    tensor = conv2d(tensor, weight=kernel, groups=channels)
+
+    return tensor
+
+
 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
     return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
         0
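
Usage sketch (reviewer note, not part of the patch): blurring a PIL image through
the new helper. The input file name is hypothetical, and the example relies on
image_to_tensor producing a 1xCxHxW float32 tensor in [0, 1], as shown in the
context lines above.

    from PIL import Image
    from refiners.fluxion.utils import gaussian_blur, image_to_tensor

    image = Image.open("input.png").convert("RGB")  # hypothetical input file
    tensor = image_to_tensor(image)  # 1xCxHxW, float32 in [0, 1]
    # an odd kernel size keeps the spatial size thanks to the reflect padding
    blurred = gaussian_blur(tensor, kernel_size=9, sigma=2.0)
    assert blurred.shape == tensor.shape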