mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
fix sdxl structural copy
This commit is contained in:
parent
ca9e89b22a
commit
11da76f7df
|
@ -84,3 +84,16 @@ class DoubleTextEncoder(fl.Chain):
|
||||||
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
|
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
|
||||||
text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1)
|
text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1)
|
||||||
return text_embedding, pooled_text_embedding
|
return text_embedding, pooled_text_embedding
|
||||||
|
|
||||||
|
def structural_copy(self: "DoubleTextEncoder") -> "DoubleTextEncoder":
|
||||||
|
old_tep = self.ensure_find(TextEncoderWithPooling)
|
||||||
|
old_tep.eject()
|
||||||
|
copy = super().structural_copy()
|
||||||
|
old_tep.inject()
|
||||||
|
|
||||||
|
new_text_encoder_g = copy.ensure_find(CLIPTextEncoderG)
|
||||||
|
projection = old_tep.layer(("Parallel", "Chain", "Linear"), fl.Linear)
|
||||||
|
|
||||||
|
new_tep = TextEncoderWithPooling(target=new_text_encoder_g, projection=projection)
|
||||||
|
new_tep.inject(copy.layer("Parallel", fl.Parallel))
|
||||||
|
return copy
|
||||||
|
|
|
@ -1590,10 +1590,14 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_sdxl_random_init(
|
@pytest.mark.parametrize("structural_copy", [False, True])
|
||||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
def test_diffusion_sdxl_random_init(
|
||||||
|
sdxl_ddim: StableDiffusion_XL,
|
||||||
|
expected_sdxl_ddim_random_init: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
structural_copy: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
sdxl = sdxl_ddim
|
sdxl = sdxl_ddim.structural_copy() if structural_copy else sdxl_ddim
|
||||||
expected_image = expected_sdxl_ddim_random_init
|
expected_image = expected_sdxl_ddim_random_init
|
||||||
|
|
||||||
prompt = "a cute cat, detailed high-quality professional image"
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
|
Loading…
Reference in a new issue