mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
parent
d199cd4f24
commit
446967859d
|
@ -2283,29 +2283,9 @@ def test_style_aligned(
|
||||||
]
|
]
|
||||||
|
|
||||||
# create (context) embeddings from prompts
|
# create (context) embeddings from prompts
|
||||||
# TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
unconds: list[torch.Tensor] = []
|
text=set_of_prompts, negative_text=[""] * len(set_of_prompts)
|
||||||
conds: list[torch.Tensor] = []
|
)
|
||||||
pooled_unconds: list[torch.Tensor] = []
|
|
||||||
pooled_conds: list[torch.Tensor] = []
|
|
||||||
for prompt in set_of_prompts:
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt)
|
|
||||||
|
|
||||||
uncond, cond = clip_text_embedding.chunk(2)
|
|
||||||
pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2)
|
|
||||||
|
|
||||||
unconds.append(uncond)
|
|
||||||
conds.append(cond)
|
|
||||||
pooled_unconds.append(pooled_uncond)
|
|
||||||
pooled_conds.append(pooled_cond)
|
|
||||||
|
|
||||||
uncond = torch.cat(unconds, dim=0)
|
|
||||||
cond = torch.cat(conds, dim=0)
|
|
||||||
pooled_uncond = torch.cat(pooled_unconds, dim=0)
|
|
||||||
pooled_cond = torch.cat(pooled_conds, dim=0)
|
|
||||||
|
|
||||||
clip_text_embedding = torch.cat((uncond, cond), dim=0)
|
|
||||||
pooled_text_embedding = torch.cat((pooled_uncond, pooled_cond), dim=0)
|
|
||||||
|
|
||||||
time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)
|
time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue