From 364e196874f64653232c47dc662bad17b3d45196 Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Fri, 22 Mar 2024 16:47:38 +0100
Subject: [PATCH] support no CFG in compute_clip_text_embedding

---
 .../latent_diffusion/stable_diffusion_1/model.py  | 4 ++++
 .../latent_diffusion/stable_diffusion_xl/model.py | 4 ++++
 tests/e2e/test_lcm.py                             | 7 ++-----
 tests/e2e/test_lightning.py                       | 6 +++---
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
index b48eb1f..873fc4b 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -77,6 +77,10 @@ class StableDiffusion_1(LatentDiffusionModel):
                 If not provided, the negative prompt is assumed to be empty (i.e., `""`).
         """
         text = [text] if isinstance(text, str) else text
+
+        if not self.classifier_free_guidance:
+            return self.clip_text_encoder(text)
+
         negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
         assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
 
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
index 31a747d..4c4f8d7 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
@@ -77,6 +77,10 @@ class StableDiffusion_XL(LatentDiffusionModel):
         """
         text = [text] if isinstance(text, str) else text
+
+        if not self.classifier_free_guidance:
+            return self.clip_text_encoder(text)
+
         negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
         assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
 
diff --git a/tests/e2e/test_lcm.py b/tests/e2e/test_lcm.py
index 205b608..8d52250 100644
--- a/tests/e2e/test_lcm.py
+++ b/tests/e2e/test_lcm.py
@@ -114,8 +114,7 @@ def test_lcm_base(
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
     expected_image = expected_lcm_base
 
-    # *NOT* compute_clip_text_embedding! We disable classifier-free guidance.
-    clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
     time_ids = sdxl.default_time_ids
 
     manual_seed(2)
@@ -163,7 +162,6 @@ def test_lcm_lora_with_guidance(
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
     expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
 
-    # *NOT* clip_text_encoder! We use classifier-free guidance here.
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
     time_ids = sdxl.default_time_ids
     assert time_ids.shape == (2, 6)  # CFG
@@ -213,8 +211,7 @@ def test_lcm_lora_without_guidance(
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
     expected_image = expected_lcm_lora_1_0
 
-    # *NOT* compute_clip_text_embedding! We disable classifier-free guidance.
-    clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
     time_ids = sdxl.default_time_ids
     assert time_ids.shape == (1, 6)  # no CFG
 
diff --git a/tests/e2e/test_lightning.py b/tests/e2e/test_lightning.py
index 40793aa..4f5f6f6 100644
--- a/tests/e2e/test_lightning.py
+++ b/tests/e2e/test_lightning.py
@@ -127,7 +127,7 @@ def test_lightning_base_4step(
 
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
 
-    clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
     time_ids = sdxl.default_time_ids
 
     manual_seed(0)
@@ -178,7 +178,7 @@ def test_lightning_base_1step(
 
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
 
-    clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
     time_ids = sdxl.default_time_ids
 
     manual_seed(0)
@@ -232,7 +232,7 @@ def test_lightning_lora_4step(
 
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
 
-    clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
     time_ids = sdxl.default_time_ids
 
     manual_seed(0)
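
Usage sketch (illustrative, not part of the patch): it assumes an already-loaded
StableDiffusion_XL instance named `sdxl`, configured as in the e2e tests above, and
a placeholder prompt. It shows why the tests can now call compute_clip_text_embedding
in both the guided and unguided setups.

    prompt = "a cute cat"  # placeholder prompt

    # With classifier-free guidance enabled, the returned embeddings cover both the
    # negative and the positive prompt (batch dimension of 2, matching the
    # `time_ids.shape == (2, 6)` assert in the guided LCM test).
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        prompt, negative_text=""
    )

    # With classifier-free guidance disabled (the LCM / Lightning setups above), the
    # same call now returns sdxl.clip_text_encoder(prompt) directly: only the positive
    # prompt is encoded (batch dimension of 1, cf. `time_ids.shape == (1, 6)`).
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)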