diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 824cae8..d3dc0d4 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -70,9 +70,11 @@ def gaussian_blur( ) -> Float[Tensor, "*batch channels height width"]: assert torch.is_floating_point(tensor) - def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]: + def get_gaussian_kernel1d( + kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device + ) -> Float[Tensor, "kernel_size"]: ksize_half = (kernel_size - 1) * 0.5 - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device, dtype=dtype) pdf = torch.exp(-0.5 * (x / sigma).pow(2)) kernel1d = pdf / pdf.sum() return kernel1d @@ -80,8 +82,8 @@ def gaussian_blur( def get_gaussian_kernel2d( kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device ) -> Float[Tensor, "kernel_size_y kernel_size_x"]: - kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype) - kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype) + kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x, dtype, device) + kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y, dtype, device) kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) return kernel2d