diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 76d3851..5ff987d 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -135,12 +135,17 @@ def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image: @pytest.fixture def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image: - return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB") + return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB") @pytest.fixture def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image: - return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert(mode="RGB") + return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB") + + +@pytest.fixture +def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB") @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) @@ -627,6 +632,18 @@ def sdxl_ddim_lda_fp16_fix( return sdxl +@pytest.fixture +def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_XL: + return StableDiffusion_XL( + unet=sdxl_ddim.unet, + lda=sdxl_ddim.lda, + clip_text_encoder=sdxl_ddim.clip_text_encoder, + solver=Euler(num_inference_steps=30), + device=sdxl_ddim.device, + dtype=sdxl_ddim.dtype, + ) + + @no_grad() def test_diffusion_std_random_init( sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device @@ -1684,6 +1701,44 @@ def test_diffusion_sdxl_sliced_attention( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) +@no_grad() +def test_diffusion_sdxl_euler_deterministic( + sdxl_euler_deterministic: StableDiffusion_XL, expected_sdxl_euler_random_init: Image.Image +) -> None: + sdxl = sdxl_euler_deterministic + assert isinstance(sdxl.solver, Euler) + + expected_image = expected_sdxl_euler_random_init + + prompt = "a cute cat, detailed high-quality professional image" + negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" + + clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( + text=prompt, negative_text=negative_prompt + ) + time_ids = sdxl.default_time_ids + sdxl.set_inference_steps(30) + manual_seed(2) + x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) + + # init latents must be scaled for Euler + # TODO make init_latents work + x = x * sdxl.solver.init_noise_sigma + + for step in sdxl.steps: + x = sdxl( + x, + step=step, + clip_text_embedding=clip_text_embedding, + pooled_text_embedding=pooled_text_embedding, + time_ids=time_ids, + condition_scale=5, + ) + + predicted_image = sdxl.lda.decode_latents(x) + ensure_similar_images(predicted_image, expected_image) + + @no_grad() def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None: manual_seed(seed=2) diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index bdecfc6..7096e13 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -45,6 +45,7 @@ Special cases: - `expected_t2i_adapter_xl_canny.png` - `expected_image_sdxl_ip_adapter_plus_woman.png` - `expected_cutecat_sdxl_ddim_random_init_sag.png` + - `expected_cutecat_sdxl_euler_random_init.png` - `expected_restart.png` - `expected_freeu.png` - `expected_dropy_slime_9752.png` diff --git a/tests/e2e/test_diffusion_ref/expected_cutecat_sdxl_euler_random_init.png b/tests/e2e/test_diffusion_ref/expected_cutecat_sdxl_euler_random_init.png new file mode 100644 index 0000000..df73e7b Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_cutecat_sdxl_euler_random_init.png differ