diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 495ac2b..91effd5 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -1,4 +1,5 @@
 import gc
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Iterator
 from warnings import warn
@@ -27,6 +28,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import Refer
 from refiners.foundationals.latent_diffusion.restart import Restart
 from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLoraAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 from tests.utils import ensure_similar_images
 
@@ -185,6 +187,104 @@ def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str,
     return cn_name, condition_image, expected_image, weights_path
 
 
+@dataclass(frozen=True)
+class ControlLoraConfig:
+    """Declarative description of one Control LoRA used by a test case.
+
+    `condition_path` and `weights_path` are bare file names; the
+    `controllora_sdxl_config` fixture resolves them against the reference images
+    directory and the `control_lora` weights directory respectively.
+    """
+
+    scale: float
+    condition_path: str
+    weights_path: str
+
+
+@dataclass(frozen=True)
+class ControlLoraResolvedConfig:
+    """Test-ready counterpart of `ControlLoraConfig`.
+
+    The condition image is loaded (as RGB) and the weights path is resolved
+    against the test weights directory.
+    """
+
+    scale: float
+    condition_image: Image.Image
+    weights_path: Path
+
+
+# Maps each expected reference image (file name) to the Control LoRAs, keyed by
+# adapter name, that are active when generating it.
+CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
+    "expected_controllora_PyraCanny.png": {
+        "PyraCanny": ControlLoraConfig(
+            scale=1.0,
+            condition_path="cutecat_guide_PyraCanny.png",
+            weights_path="refiners_control-lora-canny-rank128.safetensors",
+        ),
+    },
+    "expected_controllora_CPDS.png": {
+        "CPDS": ControlLoraConfig(
+            scale=1.0,
+            condition_path="cutecat_guide_CPDS.png",
+            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
+        ),
+    },
+    "expected_controllora_PyraCanny+CPDS.png": {
+        "PyraCanny": ControlLoraConfig(
+            scale=0.55,
+            condition_path="cutecat_guide_PyraCanny.png",
+            weights_path="refiners_control-lora-canny-rank128.safetensors",
+        ),
+        "CPDS": ControlLoraConfig(
+            scale=0.55,
+            condition_path="cutecat_guide_CPDS.png",
+            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
+        ),
+    },
+    "expected_controllora_disabled.png": {
+        "PyraCanny": ControlLoraConfig(
+            scale=0.0,
+            condition_path="cutecat_guide_PyraCanny.png",
+            weights_path="refiners_control-lora-canny-rank128.safetensors",
+        ),
+        "CPDS": ControlLoraConfig(
+            scale=0.0,
+            condition_path="cutecat_guide_CPDS.png",
+            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
+        ),
+    },
+}
+
+
+@pytest.fixture(params=CONTROL_LORA_CONFIGS.items(), ids=list(CONTROL_LORA_CONFIGS))
+def controllora_sdxl_config(
+    request: pytest.FixtureRequest,
+    ref_path: Path,
+    test_weights_path: Path,
+) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
+    """Resolve one `CONTROL_LORA_CONFIGS` entry into ready-to-use test inputs.
+
+    Returns the expected image and, for each adapter name, the scale, the loaded
+    condition image and the path to the weights (all images converted to RGB).
+    """
+    name: str = request.param[0]
+    configs: dict[str, ControlLoraConfig] = request.param[1]
+    expected_image = Image.open(ref_path / name).convert("RGB")
+
+    loaded_configs = {
+        config_name: ControlLoraResolvedConfig(
+            scale=config.scale,
+            condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
+            weights_path=test_weights_path / "control_lora" / config.weights_path,
+        )
+        for config_name, config in configs.items()
+    }
+
+    return expected_image, loaded_configs
+
+
 @pytest.fixture(scope="module")
 def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
     name = "depth"
@@ -1074,6 +1174,84 @@ def test_diffusion_controlnet_stack(
     ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
 
 
+@no_grad()
+def test_diffusion_sdxl_controllora(
+    controllora_sdxl_config: tuple[Image.Image, dict[str, ControlLoraResolvedConfig]],
+    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
+) -> None:
+    """Check SDXL generation with Control LoRA adapters against a reference image.
+
+    Every adapter described by the fixture is created, given its condition image
+    and injected before a full denoising loop is run; the decoded result is then
+    compared with the expected image for that parametrization.
+    """
+    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
+    sdxl.dtype = torch.float16  # FIXME: should not be necessary
+
+    expected_image, configs = controllora_sdxl_config
+
+    adapters: dict[str, ControlLoraAdapter] = {}
+    for config_name, config in configs.items():
+        adapter = ControlLoraAdapter(
+            name=config_name,
+            scale=config.scale,
+            target=sdxl.unet,
+            weights=load_from_safetensors(
+                path=config.weights_path,
+                device=sdxl.device,
+            ),
+        )
+        adapter.set_condition(
+            image_to_tensor(
+                image=config.condition_image,
+                device=sdxl.device,
+                dtype=sdxl.dtype,
+            )
+        )
+        adapters[config_name] = adapter
+
+    # inject all the control lora adapters
+    for adapter in adapters.values():
+        adapter.inject()
+
+    # compute the text embeddings
+    prompt = "a cute cat, flying in the air, detailed high-quality professional image, blank background"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality, watermarks"
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text=prompt,
+        negative_text=negative_prompt,
+    )
+
+    # initialize the latents
+    manual_seed(2)
+    x = torch.randn(
+        (1, 4, 128, 128),
+        device=sdxl.device,
+        dtype=sdxl.dtype,
+    )
+
+    # denoise
+    for step in sdxl.steps:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=sdxl.default_time_ids,
+        )
+
+    # decode latent to image
+    predicted_image = sdxl.lda.decode_latents(x)
+
+    # ensure the predicted image is similar to the expected image
+    ensure_similar_images(
+        img_1=predicted_image,
+        img_2=expected_image,
+        min_psnr=35,
+        min_ssim=0.99,
+    )
+
+
 @no_grad()
 def test_diffusion_lora(
     sd15_std: StableDiffusion_1,
diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md
index 7096e13..be87780 100644
--- a/tests/e2e/test_diffusion_ref/README.md
+++ b/tests/e2e/test_diffusion_ref/README.md
@@ -52,6 +52,10 @@ Special cases:
   - `expected_sdxl_dpo_lora.png`
   - `expected_sdxl_multi_loras.png`
   - `expected_image_ip_adapter_multi.png`
+  - `expected_controllora_CPDS.png`
+  - `expected_controllora_PyraCanny.png`
+  - `expected_controllora_PyraCanny+CPDS.png`
+  - `expected_controllora_disabled.png`
 
 ## Other images
 
@@ -81,6 +85,8 @@ Special cases:
 
 - `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png).
 
+- `cutecat_guide_PyraCanny.png` and `cutecat_guide_CPDS.png` were [generated inside Fooocus](https://github.com/lllyasviel/Fooocus/blob/e8d88d3e250e541c6daf99d6ef734e8dc3cfdc7f/extras/preprocessors.py).
+
 ## VAE without randomness
 
 ```diff
diff --git a/tests/e2e/test_diffusion_ref/cutecat_guide_CPDS.png b/tests/e2e/test_diffusion_ref/cutecat_guide_CPDS.png
new file mode 100644
index 0000000..ed9dc94
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/cutecat_guide_CPDS.png differ
diff --git a/tests/e2e/test_diffusion_ref/cutecat_guide_PyraCanny.png b/tests/e2e/test_diffusion_ref/cutecat_guide_PyraCanny.png
new file mode 100644
index 0000000..c990b17
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/cutecat_guide_PyraCanny.png differ
diff --git a/tests/e2e/test_diffusion_ref/expected_controllora_CPDS.png b/tests/e2e/test_diffusion_ref/expected_controllora_CPDS.png
new file mode 100644
index 0000000..04590c6
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_controllora_CPDS.png differ
diff --git a/tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny+CPDS.png b/tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny+CPDS.png
new file mode 100644
index 0000000..0dad209
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny+CPDS.png differ
diff --git a/tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny.png b/tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny.png
new file mode 100644
index 0000000..2d8b746
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny.png differ
diff --git a/tests/e2e/test_diffusion_ref/expected_controllora_disabled.png b/tests/e2e/test_diffusion_ref/expected_controllora_disabled.png
new file mode 100644
index 0000000..e8e4381
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_controllora_disabled.png differ