mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
make gaussian_blur work with float16
This commit is contained in:
parent
7d2abf6fbc
commit
05126c8f4d
|
@ -91,7 +91,7 @@ def gaussian_blur(
|
||||||
sx, sy = sigma
|
sx, sy = sigma
|
||||||
|
|
||||||
channels = tensor.shape[-3]
|
channels = tensor.shape[-3]
|
||||||
kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=torch.float32, device=tensor.device)
|
kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=tensor.dtype, device=tensor.device)
|
||||||
kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])
|
kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])
|
||||||
|
|
||||||
# pad = (left, right, top, bottom)
|
# pad = (left, right, top, bottom)
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
||||||
|
from torch import device as Device, dtype as DType
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -14,6 +16,7 @@ class BlurInput:
|
||||||
image_height: int = 512
|
image_height: int = 512
|
||||||
image_width: int = 512
|
image_width: int = 512
|
||||||
batch_size: int | None = 1
|
batch_size: int | None = 1
|
||||||
|
dtype: DType = torch.float32
|
||||||
|
|
||||||
|
|
||||||
BLUR_INPUTS: list[BlurInput] = [
|
BLUR_INPUTS: list[BlurInput] = [
|
||||||
|
@ -22,6 +25,7 @@ BLUR_INPUTS: list[BlurInput] = [
|
||||||
BlurInput(kernel_size=9, sigma=1.0),
|
BlurInput(kernel_size=9, sigma=1.0),
|
||||||
BlurInput(kernel_size=9, sigma=1.0, image_height=768),
|
BlurInput(kernel_size=9, sigma=1.0, image_height=768),
|
||||||
BlurInput(kernel_size=(9, 5), sigma=(1.0, 0.8)),
|
BlurInput(kernel_size=(9, 5), sigma=(1.0, 0.8)),
|
||||||
|
BlurInput(kernel_size=9, dtype=torch.float16),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,9 +34,12 @@ def blur_input(request: pytest.FixtureRequest) -> BlurInput:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
def test_gaussian_blur(blur_input: BlurInput) -> None:
|
def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
|
||||||
|
if test_device.type == "cpu" and blur_input.dtype == torch.float16:
|
||||||
|
warn("half float is not supported on the CPU because of `torch.mm`, skipping")
|
||||||
|
pytest.skip()
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
tensor = torch.randn(3, blur_input.image_height, blur_input.image_width)
|
tensor = torch.randn(3, blur_input.image_height, blur_input.image_width, device=test_device, dtype=blur_input.dtype)
|
||||||
if blur_input.batch_size is not None:
|
if blur_input.batch_size is not None:
|
||||||
tensor = tensor.expand(blur_input.batch_size, -1, -1, -1)
|
tensor = tensor.expand(blur_input.batch_size, -1, -1, -1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue