diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 0e1a6c5..b282718 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -108,6 +108,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
@@ -1334,6 +1339,46 @@ def test_diffusion_ip_adapter(
     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
 
 
+@no_grad()
+def test_diffusion_ip_adapter_multi(
+    sd15_ddim_lda_ft_mse: StableDiffusion_1,
+    ip_adapter_weights: Path,
+    image_encoder_weights: Path,
+    woman_image: Image.Image,
+    statue_image: Image.Image,
+    expected_image_ip_adapter_multi: Image.Image,
+    test_device: torch.device,
+):
+    sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
+
+    prompt = "best quality, high quality"
+    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
+    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.inject()
+
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_image_embedding = ip_adapter.compute_clip_image_embedding([woman_image, statue_image], weights=[1.0, 1.4])
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
+
+    sd15.set_inference_steps(50)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image_ip_adapter_multi)
+
+
 @no_grad()
 def test_diffusion_sdxl_ip_adapter(
     sdxl_ddim: StableDiffusion_XL,
diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md
index fdd22b1..bdecfc6 100644
--- a/tests/e2e/test_diffusion_ref/README.md
+++ b/tests/e2e/test_diffusion_ref/README.md
@@ -50,6 +50,7 @@ Special cases:
   - `expected_dropy_slime_9752.png`
   - `expected_sdxl_dpo_lora.png`
   - `expected_sdxl_multi_loras.png`
+  - `expected_image_ip_adapter_multi.png`
 
 ## Other images
 
diff --git a/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_multi.png b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_multi.png
new file mode 100644
index 0000000..31654c6
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_multi.png differ