From 446967859d8816ad214998e2ab33b268d26f8ed4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Wed, 21 Feb 2024 15:21:23 +0000 Subject: [PATCH] test_style_aligned: switch to CLIP text batch API Added in #263 --- tests/e2e/test_diffusion.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 95aac98..0573526 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -2283,29 +2283,9 @@ def test_style_aligned( ] # create (context) embeddings from prompts - # TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged - unconds: list[torch.Tensor] = [] - conds: list[torch.Tensor] = [] - pooled_unconds: list[torch.Tensor] = [] - pooled_conds: list[torch.Tensor] = [] - for prompt in set_of_prompts: - clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt) - - uncond, cond = clip_text_embedding.chunk(2) - pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2) - - unconds.append(uncond) - conds.append(cond) - pooled_unconds.append(pooled_uncond) - pooled_conds.append(pooled_cond) - - uncond = torch.cat(unconds, dim=0) - cond = torch.cat(conds, dim=0) - pooled_uncond = torch.cat(pooled_unconds, dim=0) - pooled_cond = torch.cat(pooled_conds, dim=0) - - clip_text_embedding = torch.cat((uncond, cond), dim=0) - pooled_text_embedding = torch.cat((pooled_uncond, pooled_cond), dim=0) + clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( + text=set_of_prompts, negative_text=[""] * len(set_of_prompts) + ) time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)