make gaussian_blur work with float16

This commit is contained in:
Cédric Deltheil 2023-10-07 15:39:58 +02:00 committed by Cédric Deltheil
parent 7d2abf6fbc
commit 05126c8f4d
2 changed files with 10 additions and 3 deletions

View file

@ -91,7 +91,7 @@ def gaussian_blur(
sx, sy = sigma sx, sy = sigma
channels = tensor.shape[-3] channels = tensor.shape[-3]
kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=torch.float32, device=tensor.device) kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=tensor.dtype, device=tensor.device)
kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1]) kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])
# pad = (left, right, top, bottom) # pad = (left, right, top, bottom)

View file

@ -1,6 +1,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from warnings import warn
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
from torch import device as Device, dtype as DType
import pytest import pytest
import torch import torch
@ -14,6 +16,7 @@ class BlurInput:
image_height: int = 512 image_height: int = 512
image_width: int = 512 image_width: int = 512
batch_size: int | None = 1 batch_size: int | None = 1
dtype: DType = torch.float32
BLUR_INPUTS: list[BlurInput] = [ BLUR_INPUTS: list[BlurInput] = [
@ -22,6 +25,7 @@ BLUR_INPUTS: list[BlurInput] = [
BlurInput(kernel_size=9, sigma=1.0), BlurInput(kernel_size=9, sigma=1.0),
BlurInput(kernel_size=9, sigma=1.0, image_height=768), BlurInput(kernel_size=9, sigma=1.0, image_height=768),
BlurInput(kernel_size=(9, 5), sigma=(1.0, 0.8)), BlurInput(kernel_size=(9, 5), sigma=(1.0, 0.8)),
BlurInput(kernel_size=9, dtype=torch.float16),
] ]
@ -30,9 +34,12 @@ def blur_input(request: pytest.FixtureRequest) -> BlurInput:
return request.param return request.param
def test_gaussian_blur(blur_input: BlurInput) -> None: def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
if test_device.type == "cpu" and blur_input.dtype == torch.float16:
warn("half float is not supported on the CPU because of `torch.mm`, skipping")
pytest.skip()
manual_seed(2) manual_seed(2)
tensor = torch.randn(3, blur_input.image_height, blur_input.image_width) tensor = torch.randn(3, blur_input.image_height, blur_input.image_width, device=test_device, dtype=blur_input.dtype)
if blur_input.batch_size is not None: if blur_input.batch_size is not None:
tensor = tensor.expand(blur_input.batch_size, -1, -1, -1) tensor = tensor.expand(blur_input.batch_size, -1, -1, -1)