refiners/tests/e2e/test_preprocessors.py
Pierre Chapuis 97b162d9a0 add InformativeDrawings
https://github.com/carolineec/informative-drawings

This is the preprocessor for the Lineart ControlNet.
2023-08-16 12:29:09 +02:00

57 lines
2 KiB
Python

import torch
import pytest
from warnings import warn
from PIL import Image
from pathlib import Path
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
from tests.utils import ensure_similar_images
@pytest.fixture(scope="module")
def diffusion_ref_path(test_e2e_path: Path) -> Path:
    """Directory holding the diffusion reference images, under the e2e test root."""
    ref_dir = test_e2e_path.joinpath("test_diffusion_ref")
    return ref_dir
@pytest.fixture(scope="module")
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
    """Input image for the preprocessor test, loaded and converted to RGB."""
    image_file = diffusion_ref_path.joinpath("cutecat_init.png")
    source = Image.open(image_file)
    return source.convert("RGB")
@pytest.fixture
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
    """Reference lineart image the InformativeDrawings output is compared against (RGB)."""
    reference_file = diffusion_ref_path.joinpath("cutecat_guide_lineart.png")
    reference = Image.open(reference_file)
    return reference.convert("RGB")
@pytest.fixture(scope="module")
def informative_drawings_weights(test_weights_path: Path) -> Path:
    """Return the path to the InformativeDrawings safetensors checkpoint.

    If the file is absent, emit a warning and skip every test that depends on
    this fixture instead of failing them.
    """
    weights = test_weights_path / "informative-drawings.safetensors"
    if not weights.is_file():
        # Report the missing file itself (not just its parent directory) so the
        # message tells the user exactly what to download / convert.
        message = f"could not find weights at {weights}, skipping"
        warn(message)
        pytest.skip(message, allow_module_level=True)
    return weights
@pytest.fixture
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
    """Build an InformativeDrawings model on `test_device` and load the checkpoint weights into it."""
    drawings = InformativeDrawings(device=test_device)
    state_dict = load_from_safetensors(informative_drawings_weights)
    drawings.load_state_dict(state_dict)
    return drawings
@torch.no_grad()
def test_preprocessor_informative_drawing(
    informative_drawings_model: InformativeDrawings,
    cutecat_init: Image.Image,
    expected_image_informative_drawings: Image.Image,
    test_device: torch.device,
) -> None:
    """Run InformativeDrawings on the cute-cat image and compare it to the reference lineart.

    The model output is grayscale; it is expanded to three channels so it can be
    turned into an RGB image and compared with the RGB reference image.
    """
    # The `cutecat_init` fixture already converts the image to RGB, so no extra
    # conversion is needed here.
    in_tensor = image_to_tensor(cutecat_init, device=test_device)
    out_tensor = informative_drawings_model(in_tensor)
    # Grayscale to RGB: assumes a 4-D (batch, 1, H, W) output — TODO confirm shape.
    rgb_tensor = out_tensor.repeat(1, 3, 1, 1)
    image = tensor_to_image(rgb_tensor)
    ensure_similar_images(image, expected_image_informative_drawings)