batch sdxl + sd1 + compute_clip_text_embedding

Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
Pierre Colle 2024-02-05 22:19:45 +01:00 committed by Cédric Deltheil
parent 8139b2dd91
commit d199cd4f24
5 changed files with 163 additions and 23 deletions

View file

@ -68,7 +68,7 @@ class StableDiffusion_1(LatentDiffusionModel):
dtype=dtype,
)
def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
Args:
@ -76,12 +76,14 @@ class StableDiffusion_1(LatentDiffusionModel):
negative_text: The negative prompt to compute the CLIP text embedding of.
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
"""
conditional_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
text = [text] if isinstance(text, str) else text
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
negative_embedding = self.clip_text_encoder(negative_text or "")
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
conditional_embedding = self.clip_text_encoder(text)
negative_embedding = self.clip_text_encoder(negative_text)
return torch.cat((negative_embedding, conditional_embedding))
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
"""Set the various context parameters required by the U-Net model.

View file

@ -65,7 +65,9 @@ class StableDiffusion_XL(LatentDiffusionModel):
dtype=dtype,
)
def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
def compute_clip_text_embedding(
self, text: str | list[str], negative_text: str | list[str] = ""
) -> tuple[Tensor, Tensor]:
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
Args:
@ -73,14 +75,13 @@ class StableDiffusion_XL(LatentDiffusionModel):
negative_text: The negative prompt to compute the CLIP text embedding of.
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
"""
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0
)
# TODO: when negative_text is None, use zero tensor?
negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")
text = [text] if isinstance(text, str) else text
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text)
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0

View file

@ -1,7 +1,7 @@
from typing import cast
from jaxtyping import Float
from torch import Tensor, cat, device as Device, dtype as DType
from torch import Tensor, cat, device as Device, dtype as DType, split
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
@ -40,7 +40,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
def init_context(self) -> Contexts:
return {"text_encoder_pooling": {"end_of_text_index": []}}
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 1280"], Float[Tensor, "batch 1280"]]:
return super().__call__(text)
@property
@ -48,13 +48,14 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
return self.ensure_find(CLIPTokenizer)
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
end_of_text_index.append(cast(int, position))
for str_tokens in split(tokens, 1):
position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() # type: ignore
end_of_text_index.append(cast(int, position))
def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
def pool(self, x: Float[Tensor, "batch 77 1280"]) -> Float[Tensor, "batch 1280"]:
end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
assert len(end_of_text_index) == 1, "End of text index not found."
return x[:, end_of_text_index[0], :]
assert len(end_of_text_index) == x.shape[0], "End of text index not found."
return cat([x[i : i + 1, end_of_text_index[i], :] for i in range(x.shape[0])], dim=0)
class DoubleTextEncoder(fl.Chain):
@ -75,7 +76,7 @@ class DoubleTextEncoder(fl.Chain):
tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
tep.inject(self.layer("Parallel", fl.Parallel))
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 2048"], Float[Tensor, "batch 1280"]]:
return super().__call__(text)
def concatenate_embeddings(

View file

@ -757,6 +757,60 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init)
@no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
sd15 = sd15_std
prompt1 = "a cute cat, detailed high-quality professional image"
negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
prompt2 = "a cute dog"
negative_prompt2 = "lowres, bad anatomy, bad hands"
clip_text_embedding_b2 = sd15.compute_clip_text_embedding(
text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
)
step = sd15.steps[0]
manual_seed(2)
rand_b2 = torch.randn(2, 4, 64, 64, device=sd15.device)
x_b2 = sd15(
rand_b2,
step=step,
clip_text_embedding=clip_text_embedding_b2,
condition_scale=7.5,
)
assert x_b2.shape == (2, 4, 64, 64)
rand_1 = rand_b2[0:1]
clip_text_embedding_1 = sd15.compute_clip_text_embedding(text=[prompt1], negative_text=[negative_prompt1])
x_1 = sd15(
rand_1,
step=step,
clip_text_embedding=clip_text_embedding_1,
condition_scale=7.5,
)
rand_2 = rand_b2[1:2]
clip_text_embedding_2 = sd15.compute_clip_text_embedding(text=[prompt2], negative_text=[negative_prompt2])
x_2 = sd15(
rand_2,
step=step,
clip_text_embedding=clip_text_embedding_2,
condition_scale=7.5,
)
# The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
assert torch.allclose(
x_b2[0], x_1[0], atol=5e-3, rtol=0
), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
assert torch.allclose(
x_b2[1], x_2[0], atol=5e-3, rtol=0
), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
@no_grad()
def test_diffusion_std_random_init_euler(
sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
@ -836,7 +890,6 @@ def test_diffusion_std_random_init_float16(
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
@ -1265,6 +1318,68 @@ def test_diffusion_lora(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None:
sdxl = sdxl_ddim
prompt1 = "a cute cat, detailed high-quality professional image"
negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
prompt2 = "a cute dog"
negative_prompt2 = "lowres, bad anatomy, bad hands"
clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding(
text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
)
time_ids = sdxl.default_time_ids
time_ids_b2 = sdxl.default_time_ids.repeat(2, 1)
manual_seed(seed=2)
x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
x_1 = x_b2[0:1]
x_2 = x_b2[1:2]
x_b2 = sdxl(
x_b2,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_b2,
pooled_text_embedding=pooled_text_embedding_b2,
time_ids=time_ids_b2,
)
clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding(
text=prompt1, negative_text=negative_prompt1
)
x_1 = sdxl(
x_1,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_1,
pooled_text_embedding=pooled_text_embedding_1,
time_ids=time_ids,
)
clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding(
text=prompt2, negative_text=negative_prompt2
)
x_2 = sdxl(
x_2,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_2,
pooled_text_embedding=pooled_text_embedding_2,
time_ids=time_ids,
)
# The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
assert torch.allclose(
x_b2[0], x_1[0], atol=5e-3, rtol=0
), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
assert torch.allclose(
x_b2[1], x_2[0], atol=5e-3, rtol=0
), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
@no_grad()
def test_diffusion_sdxl_lora(
sdxl_ddim: StableDiffusion_XL,

View file

@ -100,3 +100,24 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder:
assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
@no_grad()
def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt1 = "A photo of a pizza."
prompt2 = "A giant duck."
double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2])
assert double_embedding_b2.shape == torch.Size([2, 77, 2048])
assert pooled_embedding_b2.shape == torch.Size([2, 1280])
double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1)
double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2)
assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)