From 11da76f7dfaddf6931bbdb9448416e494126f4e6 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 6 Feb 2024 18:58:47 +0100 Subject: [PATCH] fix sdxl structural copy --- .../stable_diffusion_xl/text_encoder.py | 13 +++++++++++++ tests/e2e/test_diffusion.py | 10 +++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py index 79ede44..ec17507 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py @@ -84,3 +84,16 @@ class DoubleTextEncoder(fl.Chain): text_embedding_g, pooled_text_embedding = text_embedding_with_pooling text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1) return text_embedding, pooled_text_embedding + + def structural_copy(self: "DoubleTextEncoder") -> "DoubleTextEncoder": + old_tep = self.ensure_find(TextEncoderWithPooling) + old_tep.eject() + copy = super().structural_copy() + old_tep.inject() + + new_text_encoder_g = copy.ensure_find(CLIPTextEncoderG) + projection = old_tep.layer(("Parallel", "Chain", "Linear"), fl.Linear) + + new_tep = TextEncoderWithPooling(target=new_text_encoder_g, projection=projection) + new_tep.inject(copy.layer("Parallel", fl.Parallel)) + return copy diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index e8ccad9..495ac2b 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -1590,10 +1590,14 @@ def test_diffusion_sdxl_ip_adapter_plus( @no_grad() -def test_sdxl_random_init( - sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device +@pytest.mark.parametrize("structural_copy", [False, True]) +def test_diffusion_sdxl_random_init( + sdxl_ddim: StableDiffusion_XL, + expected_sdxl_ddim_random_init: Image.Image, + test_device: torch.device, + structural_copy: bool, ) -> None: - sdxl = sdxl_ddim + sdxl = sdxl_ddim.structural_copy() if structural_copy else sdxl_ddim expected_image = expected_sdxl_ddim_random_init prompt = "a cute cat, detailed high-quality professional image"