support no CFG in compute_clip_text_embedding

This commit is contained in:
Pierre Chapuis 2024-03-22 16:47:38 +01:00
parent 94e8b9c23f
commit 364e196874
4 changed files with 13 additions and 8 deletions

View file

@ -77,6 +77,10 @@ class StableDiffusion_1(LatentDiffusionModel):
If not provided, the negative prompt is assumed to be empty (i.e., `""`). If not provided, the negative prompt is assumed to be empty (i.e., `""`).
""" """
text = [text] if isinstance(text, str) else text text = [text] if isinstance(text, str) else text
if not self.classifier_free_guidance:
return self.clip_text_encoder(text)
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same" assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"

View file

@ -77,6 +77,10 @@ class StableDiffusion_XL(LatentDiffusionModel):
""" """
text = [text] if isinstance(text, str) else text text = [text] if isinstance(text, str) else text
if not self.classifier_free_guidance:
return self.clip_text_encoder(text)
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same" assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"

View file

@ -114,8 +114,7 @@ def test_lcm_base(
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_base expected_image = expected_lcm_base
# *NOT* compute_clip_text_embedding! We disable classifier-free guidance. clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
time_ids = sdxl.default_time_ids time_ids = sdxl.default_time_ids
manual_seed(2) manual_seed(2)
@ -163,7 +162,6 @@ def test_lcm_lora_with_guidance(
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2 expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
# *NOT* clip_text_encoder! We use classifier-free guidance here.
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
time_ids = sdxl.default_time_ids time_ids = sdxl.default_time_ids
assert time_ids.shape == (2, 6) # CFG assert time_ids.shape == (2, 6) # CFG
@ -213,8 +211,7 @@ def test_lcm_lora_without_guidance(
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0 expected_image = expected_lcm_lora_1_0
# *NOT* compute_clip_text_embedding! We disable classifier-free guidance. clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
time_ids = sdxl.default_time_ids time_ids = sdxl.default_time_ids
assert time_ids.shape == (1, 6) # no CFG assert time_ids.shape == (1, 6) # no CFG

View file

@ -127,7 +127,7 @@ def test_lightning_base_4step(
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
time_ids = sdxl.default_time_ids time_ids = sdxl.default_time_ids
manual_seed(0) manual_seed(0)
@ -178,7 +178,7 @@ def test_lightning_base_1step(
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
time_ids = sdxl.default_time_ids time_ids = sdxl.default_time_ids
manual_seed(0) manual_seed(0)
@ -232,7 +232,7 @@ def test_lightning_lora_4step(
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
time_ids = sdxl.default_time_ids time_ids = sdxl.default_time_ids
manual_seed(0) manual_seed(0)