diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index e406aec..b48eb1f 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -68,7 +68,7 @@ class StableDiffusion_1(LatentDiffusionModel): dtype=dtype, ) - def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor: + def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor: """Compute the CLIP text embedding associated with the given prompt and negative prompt. Args: @@ -76,12 +76,14 @@ class StableDiffusion_1(LatentDiffusionModel): negative_text: The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., `""`). """ - conditional_embedding = self.clip_text_encoder(text) - if text == negative_text: - return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0) + text = [text] if isinstance(text, str) else text + negative_text = [negative_text] if isinstance(negative_text, str) else negative_text + assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same" - negative_embedding = self.clip_text_encoder(negative_text or "") - return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0) + conditional_embedding = self.clip_text_encoder(text) + negative_embedding = self.clip_text_encoder(negative_text) + + return torch.cat((negative_embedding, conditional_embedding)) def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: """Set the various context parameters required by the U-Net model. diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py index 606448c..7899365 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py @@ -65,7 +65,9 @@ class StableDiffusion_XL(LatentDiffusionModel): dtype=dtype, ) - def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]: + def compute_clip_text_embedding( + self, text: str | list[str], negative_text: str | list[str] = "" + ) -> tuple[Tensor, Tensor]: """Compute the CLIP text embedding associated with the given prompt and negative prompt. Args: @@ -73,14 +75,13 @@ class StableDiffusion_XL(LatentDiffusionModel): negative_text: The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., `""`). """ - conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text) - if text == negative_text: - return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat( - tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0 - ) - # TODO: when negative_text is None, use zero tensor? - negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "") + text = [text] if isinstance(text, str) else text + negative_text = [negative_text] if isinstance(negative_text, str) else negative_text + assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same" + + conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text) + negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text) return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat( tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0 diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py index ec17507..39051a8 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py @@ -1,7 +1,7 @@ from typing import cast from jaxtyping import Float -from torch import Tensor, cat, device as Device, dtype as DType +from torch import Tensor, cat, device as Device, dtype as DType, split import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter @@ -40,7 +40,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]): def init_context(self) -> Contexts: return {"text_encoder_pooling": {"end_of_text_index": []}} - def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]: + def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 1280"], Float[Tensor, "batch 1280"]]: return super().__call__(text) @property @@ -48,13 +48,14 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]): return self.ensure_find(CLIPTokenizer) def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None: - position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() - end_of_text_index.append(cast(int, position)) + for str_tokens in split(tokens, 1): + position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() # type: ignore + end_of_text_index.append(cast(int, position)) - def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]: + def pool(self, x: Float[Tensor, "batch 77 1280"]) -> Float[Tensor, "batch 1280"]: end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", []) - assert len(end_of_text_index) == 1, "End of text index not found." - return x[:, end_of_text_index[0], :] + assert len(end_of_text_index) == x.shape[0], "End of text index not found." + return cat([x[i : i + 1, end_of_text_index[i], :] for i in range(x.shape[0])], dim=0) class DoubleTextEncoder(fl.Chain): @@ -75,7 +76,7 @@ class DoubleTextEncoder(fl.Chain): tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection) tep.inject(self.layer("Parallel", fl.Parallel)) - def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]: + def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 2048"], Float[Tensor, "batch 1280"]]: return super().__call__(text) def concatenate_embeddings( diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 9631d83..95aac98 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -757,6 +757,60 @@ def test_diffusion_std_random_init( ensure_similar_images(predicted_image, expected_image_std_random_init) +@no_grad() +def test_diffusion_batch2(sd15_std: StableDiffusion_1): + sd15 = sd15_std + + prompt1 = "a cute cat, detailed high-quality professional image" + negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality" + prompt2 = "a cute dog" + negative_prompt2 = "lowres, bad anatomy, bad hands" + + clip_text_embedding_b2 = sd15.compute_clip_text_embedding( + text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2] + ) + + step = sd15.steps[0] + + manual_seed(2) + rand_b2 = torch.randn(2, 4, 64, 64, device=sd15.device) + + x_b2 = sd15( + rand_b2, + step=step, + clip_text_embedding=clip_text_embedding_b2, + condition_scale=7.5, + ) + + assert x_b2.shape == (2, 4, 64, 64) + + rand_1 = rand_b2[0:1] + clip_text_embedding_1 = sd15.compute_clip_text_embedding(text=[prompt1], negative_text=[negative_prompt1]) + x_1 = sd15( + rand_1, + step=step, + clip_text_embedding=clip_text_embedding_1, + condition_scale=7.5, + ) + + rand_2 = rand_b2[1:2] + clip_text_embedding_2 = sd15.compute_clip_text_embedding(text=[prompt2], negative_text=[negative_prompt2]) + x_2 = sd15( + rand_2, + step=step, + clip_text_embedding=clip_text_embedding_2, + condition_scale=7.5, + ) + + # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911 + assert torch.allclose( + x_b2[0], x_1[0], atol=5e-3, rtol=0 + ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}" + assert torch.allclose( + x_b2[1], x_2[0], atol=5e-3, rtol=0 + ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}" + + @no_grad() def test_diffusion_std_random_init_euler( sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device @@ -836,7 +890,6 @@ def test_diffusion_std_random_init_float16( condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) - ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98) @@ -1265,6 +1318,68 @@ def test_diffusion_lora( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) +@no_grad() +def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None: + sdxl = sdxl_ddim + + prompt1 = "a cute cat, detailed high-quality professional image" + negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality" + prompt2 = "a cute dog" + negative_prompt2 = "lowres, bad anatomy, bad hands" + + clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding( + text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2] + ) + + time_ids = sdxl.default_time_ids + time_ids_b2 = sdxl.default_time_ids.repeat(2, 1) + + manual_seed(seed=2) + x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) + x_1 = x_b2[0:1] + x_2 = x_b2[1:2] + + x_b2 = sdxl( + x_b2, + step=sdxl.steps[0], + clip_text_embedding=clip_text_embedding_b2, + pooled_text_embedding=pooled_text_embedding_b2, + time_ids=time_ids_b2, + ) + + clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding( + text=prompt1, negative_text=negative_prompt1 + ) + + x_1 = sdxl( + x_1, + step=sdxl.steps[0], + clip_text_embedding=clip_text_embedding_1, + pooled_text_embedding=pooled_text_embedding_1, + time_ids=time_ids, + ) + + clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding( + text=prompt2, negative_text=negative_prompt2 + ) + + x_2 = sdxl( + x_2, + step=sdxl.steps[0], + clip_text_embedding=clip_text_embedding_2, + pooled_text_embedding=pooled_text_embedding_2, + time_ids=time_ids, + ) + + # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911 + assert torch.allclose( + x_b2[0], x_1[0], atol=5e-3, rtol=0 + ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}" + assert torch.allclose( + x_b2[1], x_2[0], atol=5e-3, rtol=0 + ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}" + + @no_grad() def test_diffusion_sdxl_lora( sdxl_ddim: StableDiffusion_XL, diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index 9435b89..4563c15 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -100,3 +100,24 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3) assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3) + + +@no_grad() +def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None: + manual_seed(seed=0) + prompt1 = "A photo of a pizza." + prompt2 = "A giant duck." + + double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2]) + + assert double_embedding_b2.shape == torch.Size([2, 77, 2048]) + assert pooled_embedding_b2.shape == torch.Size([2, 1280]) + + double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1) + double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2) + + assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3) + assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3) + + assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3) + assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)