mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
support no CFG in compute_clip_text_embedding
This commit is contained in:
parent
94e8b9c23f
commit
364e196874
|
@ -77,6 +77,10 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
|
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
|
||||||
"""
|
"""
|
||||||
text = [text] if isinstance(text, str) else text
|
text = [text] if isinstance(text, str) else text
|
||||||
|
|
||||||
|
if not self.classifier_free_guidance:
|
||||||
|
return self.clip_text_encoder(text)
|
||||||
|
|
||||||
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
|
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
|
||||||
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
|
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
|
||||||
|
|
||||||
|
|
|
@ -77,6 +77,10 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text = [text] if isinstance(text, str) else text
|
text = [text] if isinstance(text, str) else text
|
||||||
|
|
||||||
|
if not self.classifier_free_guidance:
|
||||||
|
return self.clip_text_encoder(text)
|
||||||
|
|
||||||
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
|
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
|
||||||
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
|
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
|
||||||
|
|
||||||
|
|
|
@ -114,8 +114,7 @@ def test_lcm_base(
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_base
|
expected_image = expected_lcm_base
|
||||||
|
|
||||||
# *NOT* compute_clip_text_embedding! We disable classifier-free guidance.
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
|
|
||||||
time_ids = sdxl.default_time_ids
|
time_ids = sdxl.default_time_ids
|
||||||
|
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
|
@ -163,7 +162,6 @@ def test_lcm_lora_with_guidance(
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
|
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
|
||||||
|
|
||||||
# *NOT* clip_text_encoder! We use classifier-free guidance here.
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
|
||||||
time_ids = sdxl.default_time_ids
|
time_ids = sdxl.default_time_ids
|
||||||
assert time_ids.shape == (2, 6) # CFG
|
assert time_ids.shape == (2, 6) # CFG
|
||||||
|
@ -213,8 +211,7 @@ def test_lcm_lora_without_guidance(
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_lora_1_0
|
expected_image = expected_lcm_lora_1_0
|
||||||
|
|
||||||
# *NOT* compute_clip_text_embedding! We disable classifier-free guidance.
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
|
|
||||||
time_ids = sdxl.default_time_ids
|
time_ids = sdxl.default_time_ids
|
||||||
assert time_ids.shape == (1, 6) # no CFG
|
assert time_ids.shape == (1, 6) # no CFG
|
||||||
|
|
||||||
|
|
|
@ -127,7 +127,7 @@ def test_lightning_base_4step(
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
|
||||||
time_ids = sdxl.default_time_ids
|
time_ids = sdxl.default_time_ids
|
||||||
|
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
|
@ -178,7 +178,7 @@ def test_lightning_base_1step(
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
|
||||||
time_ids = sdxl.default_time_ids
|
time_ids = sdxl.default_time_ids
|
||||||
|
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
|
@ -232,7 +232,7 @@ def test_lightning_lora_4step(
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.clip_text_encoder(prompt)
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(prompt)
|
||||||
time_ids = sdxl.default_time_ids
|
time_ids = sdxl.default_time_ids
|
||||||
|
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
|
|
Loading…
Reference in a new issue