mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fluxion: add gaussian_blur to utils
This commit is contained in:
parent
f4298f87d2
commit
0dfa23fa53
|
@ -6,7 +6,7 @@ from safetensors import safe_open as _safe_open # type: ignore
|
||||||
from safetensors.torch import save_file as _save_file # type: ignore
|
from safetensors.torch import save_file as _save_file # type: ignore
|
||||||
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
from torch.nn.functional import pad as _pad, interpolate as _interpolate, conv2d # type: ignore
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
|
|
||||||
|
@ -51,6 +51,56 @@ def normalize(
|
||||||
return (tensor - pixel_mean) / pixel_std
|
return (tensor - pixel_mean) / pixel_std
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
||||||
|
def gaussian_blur(
    tensor: Float[Tensor, "*batch channels height width"],
    kernel_size: int | tuple[int, int],
    sigma: float | tuple[float, float] | None = None,
) -> Float[Tensor, "*batch channels height width"]:
    """Blur the two trailing (spatial) dimensions of a tensor with a Gaussian kernel.

    Every channel is filtered independently with the same kernel (depthwise convolution),
    after reflect-padding the spatial dimensions by half the kernel size.

    Args:
        tensor: Floating-point tensor shaped `(*batch, channels, height, width)`. Any number of
            leading batch dimensions, including none, is accepted.
        kernel_size: Size of the kernel, either one int used for both axes or an `(x, y)` pair
            (`x` runs along the width, `y` along the height). Odd sizes preserve the spatial
            shape; an even size makes the output one pixel larger along that axis.
        sigma: Standard deviation of the Gaussian, either one number used for both axes or an
            `(x, y)` pair. When `None`, each axis uses `kernel_size * 0.15 + 0.35`.

    Returns:
        The blurred tensor, with the same dtype, device and leading batch dimensions as `tensor`.
    """
    assert torch.is_floating_point(tensor)

    def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
        # Sample the unnormalized Gaussian on a grid centered on zero, then normalize so the
        # weights sum to one (the blur must not change the overall intensity).
        ksize_half = (kernel_size - 1) * 0.5
        x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        pdf = torch.exp(-0.5 * (x / sigma).pow(2))
        kernel1d = pdf / pdf.sum()
        return kernel1d

    def get_gaussian_kernel2d(
        kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
    ) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
        kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
        kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
        # Outer product: rows follow y (height), columns follow x (width).
        kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
        return kernel2d

    def default_sigma(kernel_size: int) -> float:
        return kernel_size * 0.15 + 0.35

    if isinstance(kernel_size, int):
        kx, ky = kernel_size, kernel_size
    else:
        kx, ky = kernel_size

    if sigma is None:
        sx, sy = default_sigma(kx), default_sigma(ky)
    elif isinstance(sigma, (int, float)):
        # Accept plain ints too (e.g. `sigma=2`): they are valid wherever a float is expected.
        sx, sy = float(sigma), float(sigma)
    else:
        assert isinstance(sigma, tuple)  # TODO: remove with pyright +1.1.330
        sx, sy = sigma

    channels = tensor.shape[-3]

    # Build the kernel in float32 for accuracy, then cast it to the input dtype: conv2d requires
    # the input and the weight to share the same dtype (half / bfloat16 / double inputs included).
    kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=torch.float32, device=tensor.device)
    kernel = kernel.to(dtype=tensor.dtype)
    kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])

    # Collapse all leading batch dimensions into one so that padding and convolution always see
    # a 4D (batch, channels, height, width) input; the original batch shape is restored below.
    batch_shape = tensor.shape[:-3]
    flat = tensor.reshape(-1, *tensor.shape[-3:])

    # pad = (left, right, top, bottom)
    flat = pad(flat, pad=(kx // 2, kx // 2, ky // 2, ky // 2), mode="reflect")
    # groups=channels: one copy of the kernel per channel, no mixing across channels.
    flat = conv2d(flat, weight=kernel, groups=channels)

    return flat.reshape(*batch_shape, *flat.shape[-3:])
|
||||||
|
|
||||||
|
|
||||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||||
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
||||||
0
|
0
|
||||||
|
|
Loading…
Reference in a new issue