2023-12-11 10:46:38 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
2023-08-08 10:17:16 +00:00
|
|
|
from PIL import Image
|
2024-06-24 14:22:11 +00:00
|
|
|
from tests.utils import ensure_similar_images
|
2023-08-08 10:17:16 +00:00
|
|
|
|
2023-12-29 09:59:51 +00:00
|
|
|
from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
|
2023-08-08 10:17:16 +00:00
|
|
|
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def diffusion_ref_path(test_e2e_path: Path) -> Path:
    """Return the ``test_diffusion_ref`` subdirectory of the e2e test assets directory.

    Args:
        test_e2e_path: Root directory of the end-to-end test assets (provided by another fixture).

    Returns:
        ``test_e2e_path / "test_diffusion_ref"``.
    """
    ref_dir = test_e2e_path.joinpath("test_diffusion_ref")
    return ref_dir
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
    """Load the ``cutecat_init.png`` reference input image as RGB.

    The file is opened in a ``with`` block so the underlying handle is closed
    deterministically. ``convert`` loads the pixel data and returns a new image,
    so the returned image stays usable after the file is closed.

    Args:
        diffusion_ref_path: Directory containing the diffusion reference assets.

    Returns:
        The input image converted to ``"RGB"`` mode.
    """
    with Image.open(diffusion_ref_path / "cutecat_init.png") as img:
        return img.convert("RGB")
|
2023-08-08 10:17:16 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
    """Load the expected line-art output (``cutecat_guide_lineart.png``) as RGB.

    The file is opened in a ``with`` block so the underlying handle is closed
    deterministically. ``convert`` returns a fully loaded copy that outlives the
    file handle.

    Args:
        diffusion_ref_path: Directory containing the diffusion reference assets.

    Returns:
        The reference line-art image converted to ``"RGB"`` mode.
    """
    with Image.open(diffusion_ref_path / "cutecat_guide_lineart.png") as img:
        return img.convert("RGB")
|
2023-08-08 10:17:16 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def informative_drawings_model(
    controlnet_preprocessor_info_drawings_weights_path: Path,
    test_device: torch.device,
) -> InformativeDrawings:
    """Build an ``InformativeDrawings`` preprocessor on the test device with weights loaded.

    Args:
        controlnet_preprocessor_info_drawings_weights_path: Path to the safetensors weights file.
        test_device: Device the model is created on.

    Returns:
        The preprocessor after ``load_from_safetensors`` has been applied.
    """
    weights_file = controlnet_preprocessor_info_drawings_weights_path
    preprocessor = InformativeDrawings(device=test_device)
    preprocessor.load_from_safetensors(weights_file)
    return preprocessor
|
|
|
|
|
|
|
|
|
2023-12-29 09:59:51 +00:00
|
|
|
@no_grad()
def test_preprocessor_informative_drawing(
    informative_drawings_model: InformativeDrawings,
    cutecat_init: Image.Image,
    expected_image_informative_drawings: Image.Image,
    test_device: torch.device,
):
    """Check the informative-drawings preprocessor output against the reference line-art image."""
    source = image_to_tensor(cutecat_init.convert("RGB"), device=test_device)
    grayscale = informative_drawings_model(source)
    # The model output is treated as grayscale: tile dim 1 three times so the
    # result can be turned into an RGB image and compared with the reference.
    actual = tensor_to_image(grayscale.repeat(1, 3, 1, 1))
    ensure_similar_images(actual, expected_image_informative_drawings)
|