diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 4b1bc61..418cfb9 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -30,6 +30,7 @@ from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSc from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLoraAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL +from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter from tests.utils import ensure_similar_images @@ -150,6 +151,11 @@ def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image: return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB") +@pytest.fixture +def expected_style_aligned(ref_path: Path) -> Image.Image: + return Image.open(fp=ref_path / "expected_style_aligned.png").convert(mode="RGB") + + @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) def controlnet_data( ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest @@ -2140,3 +2146,79 @@ def test_hello_world( predicted_image = sdxl.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image) + + +@no_grad() +def test_style_aligned( + sdxl_ddim_lda_fp16_fix: StableDiffusion_XL, + expected_style_aligned: Image.Image, +): + sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16) + sdxl.dtype = torch.float16 # FIXME: should not be necessary + + style_aligned_adapter = StyleAlignedAdapter(sdxl.unet) + style_aligned_adapter.inject() + + set_of_prompts = [ + "a toy train. macro photo. 3d game asset", + "a toy airplane. macro photo. 3d game asset", + "a toy bicycle. macro photo. 3d game asset", + "a toy car. macro photo. 3d game asset", + "a toy boat. macro photo. 3d game asset", + ] + + # create (context) embeddings from prompts + # TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged + unconds: list[torch.Tensor] = [] + conds: list[torch.Tensor] = [] + pooled_unconds: list[torch.Tensor] = [] + pooled_conds: list[torch.Tensor] = [] + for prompt in set_of_prompts: + clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt) + + uncond, cond = clip_text_embedding.chunk(2) + pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2) + + unconds.append(uncond) + conds.append(cond) + pooled_unconds.append(pooled_uncond) + pooled_conds.append(pooled_cond) + + uncond = torch.cat(unconds, dim=0) + cond = torch.cat(conds, dim=0) + pooled_uncond = torch.cat(pooled_unconds, dim=0) + pooled_cond = torch.cat(pooled_conds, dim=0) + + clip_text_embedding = torch.cat((uncond, cond), dim=0) + pooled_text_embedding = torch.cat((pooled_uncond, pooled_cond), dim=0) + + time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1) + + # initialize latents + manual_seed(seed=2) + x = torch.randn( + (len(set_of_prompts), 4, 128, 128), + device=sdxl.device, + dtype=sdxl.dtype, + ) + + # denoise + for step in sdxl.steps: + x = sdxl( + x, + step=step, + clip_text_embedding=clip_text_embedding, + pooled_text_embedding=pooled_text_embedding, + time_ids=time_ids, + ) + + # decode latents + predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x] + + # tile all images horizontally + merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024)) + for i in range(len(predicted_images)): + merged_image.paste(predicted_images[i], (i * 1024, 0)) + + # compare against reference image + ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99) diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index be87780..42e3594 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -56,6 +56,7 @@ Special cases: - `expected_controllora_PyraCanny.png` - `expected_controllora_PyraCanny+CPDS.png` - `expected_controllora_disabled.png` + - `expected_style_aligned.png` ## Other images diff --git a/tests/e2e/test_diffusion_ref/expected_style_aligned.png b/tests/e2e/test_diffusion_ref/expected_style_aligned.png new file mode 100644 index 0000000..dc4d278 Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_style_aligned.png differ