mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
parent
d199cd4f24
commit
446967859d
|
@ -2283,29 +2283,9 @@ def test_style_aligned(
|
|||
]
|
||||
|
||||
# create (context) embeddings from prompts
|
||||
# TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged
|
||||
unconds: list[torch.Tensor] = []
|
||||
conds: list[torch.Tensor] = []
|
||||
pooled_unconds: list[torch.Tensor] = []
|
||||
pooled_conds: list[torch.Tensor] = []
|
||||
for prompt in set_of_prompts:
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt)
|
||||
|
||||
uncond, cond = clip_text_embedding.chunk(2)
|
||||
pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2)
|
||||
|
||||
unconds.append(uncond)
|
||||
conds.append(cond)
|
||||
pooled_unconds.append(pooled_uncond)
|
||||
pooled_conds.append(pooled_cond)
|
||||
|
||||
uncond = torch.cat(unconds, dim=0)
|
||||
cond = torch.cat(conds, dim=0)
|
||||
pooled_uncond = torch.cat(pooled_unconds, dim=0)
|
||||
pooled_cond = torch.cat(pooled_conds, dim=0)
|
||||
|
||||
clip_text_embedding = torch.cat((uncond, cond), dim=0)
|
||||
pooled_text_embedding = torch.cat((pooled_uncond, pooled_cond), dim=0)
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=set_of_prompts, negative_text=[""] * len(set_of_prompts)
|
||||
)
|
||||
|
||||
time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)
|
||||
|
||||
|
|
Loading…
Reference in a new issue