mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix sdxl structural copy
This commit is contained in:
parent
ca9e89b22a
commit
11da76f7df
|
@ -84,3 +84,16 @@ class DoubleTextEncoder(fl.Chain):
|
|||
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
|
||||
text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1)
|
||||
return text_embedding, pooled_text_embedding
|
||||
|
||||
def structural_copy(self: "DoubleTextEncoder") -> "DoubleTextEncoder":
|
||||
old_tep = self.ensure_find(TextEncoderWithPooling)
|
||||
old_tep.eject()
|
||||
copy = super().structural_copy()
|
||||
old_tep.inject()
|
||||
|
||||
new_text_encoder_g = copy.ensure_find(CLIPTextEncoderG)
|
||||
projection = old_tep.layer(("Parallel", "Chain", "Linear"), fl.Linear)
|
||||
|
||||
new_tep = TextEncoderWithPooling(target=new_text_encoder_g, projection=projection)
|
||||
new_tep.inject(copy.layer("Parallel", fl.Parallel))
|
||||
return copy
|
||||
|
|
|
@ -1590,10 +1590,14 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
|||
|
||||
|
||||
@no_grad()
|
||||
def test_sdxl_random_init(
|
||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
||||
@pytest.mark.parametrize("structural_copy", [False, True])
|
||||
def test_diffusion_sdxl_random_init(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
expected_sdxl_ddim_random_init: Image.Image,
|
||||
test_device: torch.device,
|
||||
structural_copy: bool,
|
||||
) -> None:
|
||||
sdxl = sdxl_ddim
|
||||
sdxl = sdxl_ddim.structural_copy() if structural_copy else sdxl_ddim
|
||||
expected_image = expected_sdxl_ddim_random_init
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
|
|
Loading…
Reference in a new issue