diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index d8d273f..921bc9b 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -16,6 +16,7 @@ from refiners.foundationals.latent_diffusion import (
     SD1IPAdapter,
     SD1T2IAdapter,
     SDXLIPAdapter,
+    SDXLT2IAdapter,
 )
 from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
@@ -129,6 +130,7 @@ def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str,
     weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
     return cn_name, condition_image, expected_image, weights_path

+
 @pytest.fixture(scope="module")
 def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
     name = "depth"
@@ -138,6 +140,15 @@ def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str
     return name, condition_image, expected_image, weights_path


+@pytest.fixture(scope="module")
+def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
+    name = "canny"
+    condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
+    weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
+    return name, condition_image, expected_image, weights_path
+
+
 @pytest.fixture(scope="module")
 def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
     expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
@@ -1283,3 +1294,52 @@ def test_t2i_adapter_depth(
     predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image)
+
+
+@torch.no_grad()
+def test_t2i_adapter_xl_canny(
+    sdxl_ddim: StableDiffusion_XL,
+    t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
+    test_device: torch.device,
+):
+    sdxl = sdxl_ddim
+    n_steps = 30
+
+    name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny
+
+    if not weights_path.is_file():
+        warn(f"could not find weights at {weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    prompt = "Mystical fairy in real, magic, 4k picture, high quality"
+    negative_prompt = (
+        "extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
+    )
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text=prompt, negative_text=negative_prompt
+    )
+    time_ids = sdxl.default_time_ids
+
+    sdxl.set_num_inference_steps(n_steps)
+
+    t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
+    t2i_adapter.set_scale(0.8)
+
+    condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
+    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
+
+    manual_seed(2)
+    x = torch.randn(1, 4, condition_image.height // 8, condition_image.width // 8, device=test_device)
+
+    for step in sdxl.steps:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=7.5,
+        )
+    predicted_image = sdxl.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image)
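Note (not part of the patch): the sketch below condenses the T2I-Adapter wiring that the new test exercises, for readers skimming the diff. It assumes an already loaded `sdxl: StableDiffusion_XL` instance (e.g. as built by the `sdxl_ddim` fixture) and uses the guide image and weights file names from this patch as placeholders; every API call mirrors the test body above.

# Sketch only, assuming `sdxl` is an already loaded StableDiffusion_XL
# exposing its device as `sdxl.device`, and that the files below exist locally.
from PIL import Image

from refiners.fluxion.utils import image_to_tensor, load_from_safetensors
from refiners.foundationals.latent_diffusion import SDXLT2IAdapter

condition_image = Image.open("fairy_guide_canny.png").convert("RGB")  # guide image

# Inject the adapter into the UNet, then set the conditioning strength.
t2i_adapter = SDXLT2IAdapter(
    target=sdxl.unet,
    name="canny",
    weights=load_from_safetensors("t2i-adapter-canny-sdxl-1.0.safetensors"),
).inject()
t2i_adapter.set_scale(0.8)

# Condition features are computed once from the guide image, then reused at
# every step of the usual SDXL denoising loop.
condition = image_to_tensor(condition_image, device=sdxl.device)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))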
diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md
index a0aed75..ece1f96 100644
--- a/tests/e2e/test_diffusion_ref/README.md
+++ b/tests/e2e/test_diffusion_ref/README.md
@@ -35,7 +35,12 @@ output.images[0].save("std_random_init_expected.png")
 Special cases:

 - `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
-- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` and `expected_ip_adapter_controlnet.png` have been generated with refiners itself (and inspected so that they look reasonable).
+- The following references have been generated with refiners itself (and inspected so that they look reasonable):
+  - `expected_inpainting_refonly.png`
+  - `expected_image_ip_adapter_woman.png`
+  - `expected_image_sdxl_ip_adapter_woman.png`
+  - `expected_ip_adapter_controlnet.png`
+  - `expected_t2i_adapter_xl_canny.png`

 ## Other images

@@ -45,13 +50,15 @@ Special cases:

 - `kitchen_mask.png` is made manually.

-- Controlnet guides have been manually generated using open source software and models, namely:
+- Controlnet guides have been manually generated (*) using open source software and models, namely:
   - Canny: opencv-python
   - Depth: https://github.com/isl-org/ZoeDepth
   - Lineart: https://github.com/lllyasviel/ControlNet-v1-1-nightly/tree/main/annotator/lineart
   - Normals: https://github.com/baegwangbin/surface_normal_uncertainty/tree/fe2b9f1
   - SAM: https://huggingface.co/spaces/mfidabel/controlnet-segment-anything

+(*): except `fairy_guide_canny.png`, which comes from [TencentARC/t2i-adapter-canny-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0)
+
 - `cyberpunk_guide.png` [comes from Lexica](https://lexica.art/prompt/5ba40855-0d0c-4322-8722-51115985f573).

 - `inpainting-mask.png`, `inpainting-scene.png` and `inpainting-target.png` have been generated as follows:
diff --git a/tests/e2e/test_diffusion_ref/expected_t2i_adapter_xl_canny.png b/tests/e2e/test_diffusion_ref/expected_t2i_adapter_xl_canny.png
new file mode 100644
index 0000000..3b41635
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_t2i_adapter_xl_canny.png differ
diff --git a/tests/e2e/test_diffusion_ref/fairy_guide_canny.png b/tests/e2e/test_diffusion_ref/fairy_guide_canny.png
new file mode 100644
index 0000000..ce97e70
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/fairy_guide_canny.png differ
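Note (not part of the patch): the README hunk above cites opencv-python for the Canny guides. The sketch below illustrates that approach; the input filename and the (100, 200) thresholds are assumptions rather than the exact settings used for the repository's references, and `fairy_guide_canny.png` itself comes from the TencentARC repository rather than from such a script.

# Illustrative sketch of the opencv-python Canny approach cited in the README.
# Filenames and thresholds are assumptions, not the repository's settings.
import cv2
import numpy as np
from PIL import Image

source = np.array(Image.open("source_image.png").convert("L"))  # 8-bit grayscale input
edges = cv2.Canny(source, threshold1=100, threshold2=200)  # uint8 edge map, 0 or 255
Image.fromarray(edges).convert("RGB").save("guide_canny.png")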