From 2395e666d54ef790ffcb63a4454792dedce8a4d5 Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 9 Sep 2024 14:30:13 +0000 Subject: [PATCH] update `gaussian_blur` fluxion util, see https://github.com/pytorch/vision/commit/45e053b2ae470cc3ad69c3ac9c378eeff413c858 --- src/refiners/fluxion/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 824cae8..d3dc0d4 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -70,9 +70,11 @@ def gaussian_blur( ) -> Float[Tensor, "*batch channels height width"]: assert torch.is_floating_point(tensor) - def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]: + def get_gaussian_kernel1d( + kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device + ) -> Float[Tensor, "kernel_size"]: ksize_half = (kernel_size - 1) * 0.5 - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device, dtype=dtype) pdf = torch.exp(-0.5 * (x / sigma).pow(2)) kernel1d = pdf / pdf.sum() return kernel1d @@ -80,8 +82,8 @@ def gaussian_blur( def get_gaussian_kernel2d( kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device ) -> Float[Tensor, "kernel_size_y kernel_size_x"]: - kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype) - kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype) + kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x, dtype, device) + kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y, dtype, device) kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) return kernel2d