diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 8c69de2..a3f8da8 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -91,7 +91,7 @@ def gaussian_blur( sx, sy = sigma channels = tensor.shape[-3] - kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=torch.float32, device=tensor.device) + kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=tensor.dtype, device=tensor.device) kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1]) # pad = (left, right, top, bottom) diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index ff92f2e..c34c169 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -1,6 +1,8 @@ from dataclasses import dataclass +from warnings import warn from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore +from torch import device as Device, dtype as DType import pytest import torch @@ -14,6 +16,7 @@ class BlurInput: image_height: int = 512 image_width: int = 512 batch_size: int | None = 1 + dtype: DType = torch.float32 BLUR_INPUTS: list[BlurInput] = [ @@ -22,6 +25,7 @@ BLUR_INPUTS: list[BlurInput] = [ BlurInput(kernel_size=9, sigma=1.0), BlurInput(kernel_size=9, sigma=1.0, image_height=768), BlurInput(kernel_size=(9, 5), sigma=(1.0, 0.8)), + BlurInput(kernel_size=9, dtype=torch.float16), ] @@ -30,9 +34,12 @@ def blur_input(request: pytest.FixtureRequest) -> BlurInput: return request.param -def test_gaussian_blur(blur_input: BlurInput) -> None: +def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None: + if test_device.type == "cpu" and blur_input.dtype == torch.float16: + warn("half float is not supported on the CPU because of `torch.mm`, skipping") + pytest.skip() manual_seed(2) - tensor = torch.randn(3, blur_input.image_height, blur_input.image_width) + tensor = torch.randn(3, blur_input.image_height, blur_input.image_width, device=test_device, dtype=blur_input.dtype) if blur_input.batch_size is not None: tensor = tensor.expand(blur_input.batch_size, -1, -1, -1)