fluxion: add gaussian_blur to utils

This commit is contained in:
Cédric Deltheil 2023-10-05 11:13:15 +02:00 committed by Cédric Deltheil
parent f4298f87d2
commit 0dfa23fa53

View file

@ -6,7 +6,7 @@ from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
import torch
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
from torch.nn.functional import pad as _pad, interpolate as _interpolate, conv2d # type: ignore
from torch import Tensor, device as Device, dtype as DType
from jaxtyping import Float
@ -51,6 +51,56 @@ def normalize(
return (tensor - pixel_mean) / pixel_std
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
def gaussian_blur(
tensor: Float[Tensor, "*batch channels height width"],
kernel_size: int | tuple[int, int],
sigma: float | tuple[float, float] | None = None,
) -> Float[Tensor, "*batch channels height width"]:
assert torch.is_floating_point(tensor)
def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def get_gaussian_kernel2d(
kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def default_sigma(kernel_size: int) -> float:
return kernel_size * 0.15 + 0.35
if isinstance(kernel_size, int):
kx, ky = kernel_size, kernel_size
else:
kx, ky = kernel_size
if sigma is None:
sx, sy = default_sigma(kx), default_sigma(ky)
elif isinstance(sigma, float):
sx, sy = sigma, sigma
else:
assert isinstance(sigma, tuple) # TODO: remove with pyright +1.1.330
sx, sy = sigma
channels = tensor.shape[-3]
kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=torch.float32, device=tensor.device)
kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])
# pad = (left, right, top, bottom)
tensor = pad(tensor, pad=(kx // 2, kx // 2, ky // 2, ky // 2), mode="reflect")
tensor = conv2d(tensor, weight=kernel, groups=channels)
return tensor
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
0