"""Check that refiners' ``gaussian_blur`` reproduces torchvision's ``gaussian_blur`` exactly."""

from dataclasses import dataclass
from warnings import warn

import pytest
import torch
from torch import device as Device, dtype as DType
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur  # type: ignore

from refiners.fluxion.utils import gaussian_blur, manual_seed


@dataclass
class BlurInput:
    """One parametrization of the blur comparison test."""

    # Forwarded unchanged to both blur implementations.
    kernel_size: int | tuple[int, int]
    # None is forwarded as-is, deferring to each implementation's default sigma.
    sigma: float | tuple[float, float] | None = None
    image_height: int = 512
    image_width: int = 512
    # None keeps the input unbatched, i.e. a (3, H, W) tensor.
    batch_size: int | None = 1
    dtype: DType = torch.float32


BLUR_INPUTS: list[BlurInput] = [
    BlurInput(kernel_size=9),
    BlurInput(kernel_size=9, batch_size=None),
    BlurInput(kernel_size=9, sigma=1.0),
    BlurInput(kernel_size=9, sigma=1.0, image_height=768),
    BlurInput(kernel_size=(9, 5), sigma=(1.0, 0.8)),
    BlurInput(kernel_size=9, dtype=torch.float16),
]


@pytest.fixture(params=BLUR_INPUTS)
def blur_input(request: pytest.FixtureRequest) -> BlurInput:
    """Provide each entry of ``BLUR_INPUTS`` to the tests that request it."""
    return request.param


def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
    """Blur the same random image with both implementations and require identical outputs.

    ``test_device`` is a fixture defined outside this module (presumably in a
    conftest.py — confirm there).
    """
    if test_device.type == "cpu" and blur_input.dtype == torch.float16:
        reason = "half float is not supported on the CPU because of `torch.mm`, skipping"
        # Keep the warning for the warnings summary, and also give pytest the
        # reason so it appears in the skip report instead of an empty string.
        warn(reason)
        pytest.skip(reason)

    manual_seed(2)
    tensor = torch.randn(
        3,
        blur_input.image_height,
        blur_input.image_width,
        device=test_device,
        dtype=blur_input.dtype,
    )
    if blur_input.batch_size is not None:
        # Repeat the single (3, H, W) image along a new leading batch dimension.
        tensor = tensor.expand(blur_input.batch_size, -1, -1, -1)

    ref_blur = torch_gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)  # type: ignore
    our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)

    # Bitwise equality (not allclose): the test requires an exact reproduction
    # of the reference implementation, not just a numerically close one.
    assert torch.equal(our_blur, ref_blur)