fix sdxl structural copy

This commit is contained in:
Pierre Chapuis 2024-02-06 18:58:47 +01:00
parent ca9e89b22a
commit 11da76f7df
2 changed files with 20 additions and 3 deletions

View file

@ -84,3 +84,16 @@ class DoubleTextEncoder(fl.Chain):
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1)
return text_embedding, pooled_text_embedding
def structural_copy(self: "DoubleTextEncoder") -> "DoubleTextEncoder":
old_tep = self.ensure_find(TextEncoderWithPooling)
old_tep.eject()
copy = super().structural_copy()
old_tep.inject()
new_text_encoder_g = copy.ensure_find(CLIPTextEncoderG)
projection = old_tep.layer(("Parallel", "Chain", "Linear"), fl.Linear)
new_tep = TextEncoderWithPooling(target=new_text_encoder_g, projection=projection)
new_tep.inject(copy.layer("Parallel", fl.Parallel))
return copy

View file

@ -1590,10 +1590,14 @@ def test_diffusion_sdxl_ip_adapter_plus(
@no_grad()
def test_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
@pytest.mark.parametrize("structural_copy", [False, True])
def test_diffusion_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL,
expected_sdxl_ddim_random_init: Image.Image,
test_device: torch.device,
structural_copy: bool,
) -> None:
sdxl = sdxl_ddim
sdxl = sdxl_ddim.structural_copy() if structural_copy else sdxl_ddim
expected_image = expected_sdxl_ddim_random_init
prompt = "a cute cat, detailed high-quality professional image"