update gaussian_blur fluxion util, see 45e053b2ae

This commit is contained in:
Laurent 2024-09-09 14:30:13 +00:00
parent a51d695523
commit 51a4e9e8ba
No known key found for this signature in database

View file

@ -70,9 +70,11 @@ def gaussian_blur(
) -> Float[Tensor, "*batch channels height width"]:
assert torch.is_floating_point(tensor)
def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
def get_gaussian_kernel1d(
kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device
) -> Float[Tensor, "kernel_size"]:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device, dtype=dtype)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
@ -80,8 +82,8 @@ def gaussian_blur(
def get_gaussian_kernel2d(
kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x, dtype, device)
kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y, dtype, device)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d