diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index d6528c6..d7aa74b 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -17,9 +17,11 @@ from refiners.foundationals.latent_diffusion import (
     SDXLIPAdapter,
 )
 from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
+from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
 from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.clip.concepts import ConceptExtender
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 from tests.utils import ensure_similar_images
 
@@ -169,6 +171,11 @@ def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_multi_diffusion(ref_path: Path) -> Image.Image:
+    return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
+
+
 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
     return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
@@ -1179,3 +1186,34 @@ def test_sdxl_random_init(
     predicted_image = sdxl.lda.decode_latents(x=x)
 
     ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
+
+
+@torch.no_grad()
+def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
+    manual_seed(seed=2)
+    sd = sd15_ddim
+    multi_diffusion = SD1MultiDiffusion(sd)
+    clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")
+    target_1 = DiffusionTarget(
+        size=(64, 64),
+        offset=(0, 0),
+        clip_text_embedding=clip_text_embedding,
+        start_step=0,
+    )
+    target_2 = DiffusionTarget(
+        size=(64, 64),
+        offset=(0, 16),
+        clip_text_embedding=clip_text_embedding,
+        start_step=0,
+    )
+    noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)
+    x = noise
+    for step in sd.steps:
+        x = multi_diffusion(
+            x,
+            noise=noise,
+            step=step,
+            targets=[target_1, target_2],
+        )
+    result = sd.lda.decode_latents(x=x)
+    ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
diff --git a/tests/e2e/test_diffusion_ref/expected_multi_diffusion.png b/tests/e2e/test_diffusion_ref/expected_multi_diffusion.png
new file mode 100644
index 0000000..da09414
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_multi_diffusion.png differ
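
Note for reviewers: a minimal sketch of the MultiDiffusion tile-merge idea the new test exercises. This is illustrative only, not the refiners implementation; multi_diffusion_step and denoise_tile are hypothetical names, and the real SD1MultiDiffusion call additionally threads noise and step through the scheduler (as the test's loop shows), which this sketch omits.

    # Sketch: denoise each latent crop independently, then average
    # the results wherever crops overlap, keeping tile seams consistent.
    import torch

    def multi_diffusion_step(
        x: torch.Tensor,  # full latent, e.g. (1, 4, 64, 80) as in the test
        tiles: list[tuple[tuple[int, int], tuple[int, int]]],  # (size, offset) pairs
        denoise_tile,  # hypothetical callable: one denoising step on one crop
    ) -> torch.Tensor:
        cumulative = torch.zeros_like(x)
        weights = torch.zeros_like(x)
        for (h, w), (top, left) in tiles:
            crop = x[:, :, top : top + h, left : left + w]
            # Accumulate each tile's denoised result and count coverage.
            cumulative[:, :, top : top + h, left : left + w] += denoise_tile(crop)
            weights[:, :, top : top + h, left : left + w] += 1
        # Average overlapping regions; clamp guards uncovered pixels.
        return cumulative / weights.clamp(min=1)

With the test's two 64x64 targets at offsets (0, 0) and (0, 16) on a 64x80 latent, columns 16 through 63 are covered by both tiles, so their denoising results are averaged at every step, which is what blends the two crops into one seamless panorama.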