mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
fluxion: add gaussian_blur to utils
This commit is contained in:
parent
f4298f87d2
commit
0dfa23fa53
|
@ -6,7 +6,7 @@ from safetensors import safe_open as _safe_open # type: ignore
|
|||
from safetensors.torch import save_file as _save_file # type: ignore
|
||||
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
||||
import torch
|
||||
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
||||
from torch.nn.functional import pad as _pad, interpolate as _interpolate, conv2d # type: ignore
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from jaxtyping import Float
|
||||
|
||||
|
@ -51,6 +51,56 @@ def normalize(
|
|||
return (tensor - pixel_mean) / pixel_std
|
||||
|
||||
|
||||
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
||||
def gaussian_blur(
|
||||
tensor: Float[Tensor, "*batch channels height width"],
|
||||
kernel_size: int | tuple[int, int],
|
||||
sigma: float | tuple[float, float] | None = None,
|
||||
) -> Float[Tensor, "*batch channels height width"]:
|
||||
assert torch.is_floating_point(tensor)
|
||||
|
||||
def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
|
||||
ksize_half = (kernel_size - 1) * 0.5
|
||||
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
kernel1d = pdf / pdf.sum()
|
||||
return kernel1d
|
||||
|
||||
def get_gaussian_kernel2d(
|
||||
kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
|
||||
) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
|
||||
kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
|
||||
kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
|
||||
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
|
||||
return kernel2d
|
||||
|
||||
def default_sigma(kernel_size: int) -> float:
|
||||
return kernel_size * 0.15 + 0.35
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kx, ky = kernel_size, kernel_size
|
||||
else:
|
||||
kx, ky = kernel_size
|
||||
|
||||
if sigma is None:
|
||||
sx, sy = default_sigma(kx), default_sigma(ky)
|
||||
elif isinstance(sigma, float):
|
||||
sx, sy = sigma, sigma
|
||||
else:
|
||||
assert isinstance(sigma, tuple) # TODO: remove with pyright +1.1.330
|
||||
sx, sy = sigma
|
||||
|
||||
channels = tensor.shape[-3]
|
||||
kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=torch.float32, device=tensor.device)
|
||||
kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])
|
||||
|
||||
# pad = (left, right, top, bottom)
|
||||
tensor = pad(tensor, pad=(kx // 2, kx // 2, ky // 2, ky // 2), mode="reflect")
|
||||
tensor = conv2d(tensor, weight=kernel, groups=channels)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
||||
0
|
||||
|
|
Loading…
Reference in a new issue