test_style_aligned: switch to CLIP text batch API

Added in #263
This commit is contained in:
Cédric Deltheil 2024-02-21 15:21:23 +00:00 committed by Cédric Deltheil
parent d199cd4f24
commit 446967859d

View file

@ -2283,29 +2283,9 @@ def test_style_aligned(
]
# create (context) embeddings from prompts
# TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged
unconds: list[torch.Tensor] = []
conds: list[torch.Tensor] = []
pooled_unconds: list[torch.Tensor] = []
pooled_conds: list[torch.Tensor] = []
for prompt in set_of_prompts:
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt)
uncond, cond = clip_text_embedding.chunk(2)
pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2)
unconds.append(uncond)
conds.append(cond)
pooled_unconds.append(pooled_uncond)
pooled_conds.append(pooled_cond)
uncond = torch.cat(unconds, dim=0)
cond = torch.cat(conds, dim=0)
pooled_uncond = torch.cat(pooled_unconds, dim=0)
pooled_cond = torch.cat(pooled_conds, dim=0)
clip_text_embedding = torch.cat((uncond, cond), dim=0)
pooled_text_embedding = torch.cat((pooled_uncond, pooled_cond), dim=0)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=set_of_prompts, negative_text=[""] * len(set_of_prompts)
)
time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)