mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-27 16:38:46 +00:00
127 lines
4.2 KiB
Python
127 lines
4.2 KiB
Python
import pickle
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from warnings import warn
|
|
|
|
import pytest
|
|
import torch
|
|
from PIL import Image
|
|
from torch import device as Device, dtype as DType
|
|
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
|
|
|
from refiners.fluxion import layers as fl
|
|
from refiners.fluxion.utils import (
|
|
gaussian_blur,
|
|
image_to_tensor,
|
|
load_tensors,
|
|
manual_seed,
|
|
no_grad,
|
|
summarize_tensor,
|
|
tensor_to_image,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class BlurInput:
|
|
kernel_size: int | tuple[int, int]
|
|
sigma: float | tuple[float, float] | None = None
|
|
image_height: int = 512
|
|
image_width: int = 512
|
|
batch_size: int | None = 1
|
|
dtype: DType = torch.float32
|
|
|
|
|
|
# Parametrizations consumed by the `blur_input` fixture. The order is kept
# stable because pytest derives the parametrized test ids from it.
BLUR_INPUTS: list[BlurInput] = [
    # All defaults: batched 512x512 float32 image, sigma left as None.
    BlurInput(kernel_size=9),
    # Unbatched (3-D) input.
    BlurInput(kernel_size=9, batch_size=None),
    # Explicit scalar sigma.
    BlurInput(kernel_size=9, sigma=1.0),
    # Non-square image (768 high, 512 wide).
    BlurInput(kernel_size=9, sigma=1.0, image_height=768),
    # Per-axis kernel size and sigma.
    BlurInput(kernel_size=(9, 5), sigma=(1.0, 0.8)),
    # Half precision input.
    BlurInput(kernel_size=9, dtype=torch.float16),
]
|
|
|
|
|
|
@pytest.fixture(params=BLUR_INPUTS)
def blur_input(request: pytest.FixtureRequest) -> BlurInput:
    """Yield each entry of `BLUR_INPUTS` in turn as a parametrized fixture."""
    current_case: BlurInput = request.param
    return current_case
|
|
|
|
|
|
def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
    """Check that `gaussian_blur` reproduces torchvision's `gaussian_blur` exactly.

    A random image described by `blur_input` is blurred by both
    implementations with the same kernel size and sigma, and the two results
    must be equal element for element (`torch.equal`).

    Args:
        test_device: Device the tensors are created on (fixture provided
            outside this file).
        blur_input: The parametrization under test.
    """
    if test_device.type == "cpu" and blur_input.dtype == torch.float16:
        # Pass the explanation to pytest as the skip reason instead of emitting
        # a warning first: a warning turns into a test error when warnings are
        # escalated to errors, and a reason-less skip leaves the report empty.
        pytest.skip("half float is not supported on the CPU because of `torch.mm`, skipping")

    # Seed through the project helper before drawing the random input.
    manual_seed(2)
    tensor = torch.randn(
        3,
        blur_input.image_height,
        blur_input.image_width,
        device=test_device,
        dtype=blur_input.dtype,
    )
    if blur_input.batch_size is not None:
        # Prepend a batch dimension; -1 keeps the existing three dimensions as they are.
        tensor = tensor.expand(blur_input.batch_size, -1, -1, -1)

    ref_blur = torch_gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)  # type: ignore
    our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)

    assert torch.equal(our_blur, ref_blur)
|
|
|
|
|
|
def test_image_to_tensor() -> None:
    """Check the shape `image_to_tensor` produces for RGB, L and RGBA images.

    Each 512x512 image is expected to become a `(1, channels, 512, 512)`
    tensor, with 3, 1 and 4 channels respectively.
    """
    rgb_image = Image.new("RGB", (512, 512))
    channels_by_mode = {"RGB": 3, "L": 1, "RGBA": 4}

    for mode, channels in channels_by_mode.items():
        # The RGB case uses the freshly created image directly, without converting.
        candidate = rgb_image if mode == "RGB" else rgb_image.convert(mode)
        assert image_to_tensor(candidate).shape == (1, channels, 512, 512)
|
|
|
|
|
|
def test_tensor_to_image() -> None:
    """Check the PIL mode `tensor_to_image` picks for 3-, 1- and 4-channel tensors.

    A bfloat16 3-channel tensor is included and must also map to "RGB".
    """
    expectations = [
        (torch.zeros(1, 3, 512, 512), "RGB"),
        (torch.zeros(1, 1, 512, 512), "L"),
        (torch.zeros(1, 4, 512, 512), "RGBA"),
        (torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16), "RGB"),
    ]

    for tensor, expected_mode in expectations:
        assert tensor_to_image(tensor).mode == expected_mode
|
|
|
|
|
|
def test_summarize_tensor() -> None:
    """Check that `summarize_tensor` returns a truthy result for several dtypes.

    Covers int, float, double, complex, bfloat16 and bool tensors, plus an
    int tensor with a zero-sized dimension.
    """
    zeros = torch.zeros(1, 3, 512, 512)
    samples = [
        zeros.int(),
        zeros.float(),
        zeros.double(),
        torch.complex(zeros, zeros),
        zeros.bfloat16(),
        zeros.bool(),
        # Empty tensor: the channel dimension has size 0.
        torch.zeros(1, 0, 512, 512).int(),
    ]

    for sample in samples:
        assert summarize_tensor(sample)
|
|
|
|
|
|
def test_no_grad() -> None:
    """Check that the `no_grad` helper disables gradient tracking like `torch.no_grad`.

    Results computed inside either context must not require grad, while the
    same computation outside any context still does.
    """
    leaf = torch.randn(1, 1, requires_grad=True)

    # Reference behaviour from torch itself.
    with torch.no_grad():
        assert not (leaf + 1).requires_grad

    # The helper under test must behave the same way.
    with no_grad():
        assert not (leaf + 1).requires_grad

    # Gradient tracking is active again once the contexts have exited.
    assert (leaf + 1).requires_grad
|
|
|
|
|
|
def test_load_tensors_valid_pickle(tmp_path: Path) -> None:
    """Check `load_tensors` on pickles written with `torch.save`.

    A dict holding only tensors must round-trip unchanged; a dict that also
    holds a non-tensor value is expected to make `load_tensors` raise
    `AssertionError`.
    """
    key = "easy-as.weight"
    checkpoint = tmp_path / "valid.pickle"

    # Tensor-only payload: loads back equal to what was saved.
    saved = {key: torch.tensor([1.0, 2.0, 3.0])}
    torch.save(saved, checkpoint)  # type: ignore
    restored = load_tensors(checkpoint)
    assert torch.equal(restored[key], saved[key])

    # Same file overwritten with a payload containing a plain string value.
    mixed = {key: torch.tensor([1, 2, 3]), "hello": "world"}
    torch.save(mixed, checkpoint)  # type: ignore

    with pytest.raises(AssertionError):
        load_tensors(checkpoint)
|
|
|
|
|
|
def test_load_tensors_invalid_pickle(tmp_path: Path) -> None:
    """Check that `load_tensors` rejects a pickle holding a whole module.

    A `fl.Chain` instance, rather than a dict of tensors, is saved with
    `torch.save`; loading it is expected to raise `pickle.UnpicklingError`.
    """
    target = tmp_path / "invalid.pickle"
    torch.save(fl.Chain(fl.Linear(1, 1)), target)  # type: ignore

    with pytest.raises(pickle.UnpicklingError):
        load_tensors(target)
|