From 3565a4127fada4e330054d221793314e0a21329b Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Thu, 17 Aug 2023 18:34:56 +0200
Subject: [PATCH] implement DoubleTextEncoder for SDXL

---
 .../foundationals/clip/text_encoder.py        |   1 -
 .../latent_diffusion/sdxl_text_encoder.py     |  88 +++++++++++++++
 tests/foundationals/clip/test_text_encoder.py |   2 +-
 .../test_sdxl_double_encoder.py               | 102 ++++++++++++++++++
 4 files changed, 191 insertions(+), 2 deletions(-)
 create mode 100644 src/refiners/foundationals/latent_diffusion/sdxl_text_encoder.py
 create mode 100644 tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py

diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py
index 6a85101..9e43411 100644
--- a/src/refiners/foundationals/clip/text_encoder.py
+++ b/src/refiners/foundationals/clip/text_encoder.py
@@ -255,7 +255,6 @@ class CLIPTextEncoderG(CLIPTextEncoder):
             num_layers=32,
             num_attention_heads=20,
             feedforward_dim=5120,
-            use_quick_gelu=True,
             tokenizer=tokenizer,
             device=device,
             dtype=dtype,

diff --git a/src/refiners/foundationals/latent_diffusion/sdxl_text_encoder.py b/src/refiners/foundationals/latent_diffusion/sdxl_text_encoder.py
new file mode 100644
index 0000000..4123b11
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/sdxl_text_encoder.py
@@ -0,0 +1,88 @@
+from typing import cast
+from torch import device as Device, dtype as DType, Tensor, cat
+from refiners.adapters.adapter import Adapter
+from refiners.fluxion.context import Contexts
+import refiners.fluxion.layers as fl
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
+from jaxtyping import Float
+
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer
+
+
+class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
+    def __init__(
+        self,
+        target: CLIPTextEncoderG,
+        projection: fl.Linear | None = None,
+    ) -> None:
+        with self.setup_adapter(target=target):
+            tokenizer = target.find(layer_type=CLIPTokenizer)
+            assert tokenizer is not None, "Tokenizer not found."
+            super().__init__(
+                tokenizer,
+                fl.SetContext(
+                    context="text_encoder_pooling", key="end_of_text_index", callback=self.set_end_of_text_index
+                ),
+                target[1:-2],
+                fl.Parallel(
+                    fl.Identity(),
+                    fl.Chain(
+                        target[-2:],
+                        projection
+                        or fl.Linear(
+                            in_features=1280, out_features=1280, bias=False, device=target.device, dtype=target.dtype
+                        ),
+                        fl.Lambda(func=self.pool),
+                    ),
+                ),
+            )
+
+    def init_context(self) -> Contexts:
+        return {"text_encoder_pooling": {"end_of_text_index": []}}
+
+    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
+        return super().__call__(text)
+
+    @property
+    def tokenizer(self) -> CLIPTokenizer:
+        tokenizer = self.find(layer_type=CLIPTokenizer)
+        assert tokenizer is not None, "Tokenizer not found."
+        return tokenizer
+
+    def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
+        position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
+        end_of_text_index.append(cast(int, position))
+
+    def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
+        end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
+        assert len(end_of_text_index) == 1, "End of text index not found."
+        return x[:, end_of_text_index[0], :]
+
+
+class DoubleTextEncoder(fl.Chain):
+    def __init__(
+        self,
+        text_encoder_l: CLIPTextEncoderL | None = None,
+        text_encoder_g: CLIPTextEncoderG | None = None,
+        projection: fl.Linear | None = None,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
+        text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
+        text_encoder_with_pooling = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
+        super().__init__(
+            fl.Parallel(text_encoder_l[:-2], text_encoder_g),
+            fl.Lambda(func=self.concatenate_embeddings),
+        )
+        text_encoder_with_pooling.inject(parent=self.Parallel)
+
+    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
+        return super().__call__(text)
+
+    def concatenate_embeddings(
+        self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, Tensor]:
+        text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
+        text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
+        return text_embedding, pooled_text_embedding

diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py
index 75d2fe4..b5d4693 100644
--- a/tests/foundationals/clip/test_text_encoder.py
+++ b/tests/foundationals/clip/test_text_encoder.py
@@ -7,7 +7,7 @@ from pathlib import Path
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.fluxion.utils import load_from_safetensors
 
-import transformers # type: ignore
+import transformers  # type: ignore
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
 

diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py
new file mode 100644
index 0000000..9bcc8c6
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py
@@ -0,0 +1,102 @@
+from typing import Any, Protocol, cast
+from pathlib import Path
+from warnings import warn
+import pytest
+import torch
+from torch import Tensor
+
+from refiners.fluxion.utils import manual_seed
+import refiners.fluxion.layers as fl
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
+from refiners.foundationals.latent_diffusion.sdxl_text_encoder import DoubleTextEncoder
+
+
+class DiffusersSDXL(Protocol):
+    unet: fl.Module
+    text_encoder: fl.Module
+    text_encoder_2: fl.Module
+    tokenizer: fl.Module
+    tokenizer_2: fl.Module
+    vae: fl.Module
+
+    def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
+        ...
+
+    def encode_prompt(
+        self,
+        prompt: str,
+        prompt_2: str | None = None,
+        negative_prompt: str | None = None,
+        negative_prompt_2: str | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        ...
+
+
+@pytest.fixture(scope="module")
+def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
+    r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
+    if not r.is_dir():
+        warn(message=f"could not find Stability SDXL base weights at {r}, skipping")
+        pytest.skip(allow_module_level=True)
+    return r
+
+
+@pytest.fixture(scope="module")
+def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
+    from diffusers import DiffusionPipeline  # type: ignore
+
+    return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path)  # type: ignore
+
+
+@pytest.fixture(scope="module")
+def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
+    text_encoder_l = CLIPTextEncoderL()
+    text_encoder_g_with_projection = CLIPTextEncoderG()
+    text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
+
+    text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
+    text_encoder_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"
+
+    text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
+    text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encoder_g_path)
+
+    linear = text_encoder_g_with_projection.pop(index=-1)
+    assert isinstance(linear, fl.Linear)
+
+    double_text_encoder = DoubleTextEncoder(
+        text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=linear
+    )
+
+    return double_text_encoder
+
+
+@torch.no_grad()
+def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
+    manual_seed(seed=0)
+    prompt = "A photo of a pizza."
+
+    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
+        diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
+    )
+
+    double_embedding, pooled_embedding = double_text_encoder(prompt)
+
+    assert double_embedding.shape == torch.Size([1, 77, 2048])
+    assert pooled_embedding.shape == torch.Size([1, 1280])
+
+    embedding_1, embedding_2 = cast(
+        tuple[Tensor, Tensor], prompt_embeds.split(split_size=[768, 1280], dim=-1)  # type: ignore
+    )
+
+    rembedding_1, rembedding_2 = cast(
+        tuple[Tensor, Tensor], double_embedding.split(split_size=[768, 1280], dim=-1)  # type: ignore
+    )
+
+    assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
+    assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3)
+    assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
+
+    negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
+
+    assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
+    assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
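
Example usage, a minimal sketch of how DoubleTextEncoder is meant to be driven once
weights are converted: the safetensors file names mirror the double_text_encoder
fixture above and are assumptions about local paths, not fixed locations.

    import refiners.fluxion.layers as fl
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
    from refiners.foundationals.latent_diffusion.sdxl_text_encoder import DoubleTextEncoder

    # CLIP-L is loaded as-is; CLIP-G is loaded together with its text projection,
    # which is then popped off and handed to DoubleTextEncoder as the pooling
    # projection, exactly as the test fixture does.
    text_encoder_l = CLIPTextEncoderL()
    text_encoder_l.load_from_safetensors(tensors_path="CLIPTextEncoderL.safetensors")  # assumed local path

    text_encoder_g = CLIPTextEncoderG()
    text_encoder_g.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
    text_encoder_g.load_from_safetensors(tensors_path="CLIPTextEncoderGWithProjection.safetensors")  # assumed local path
    projection = text_encoder_g.pop(index=-1)
    assert isinstance(projection, fl.Linear)

    double_text_encoder = DoubleTextEncoder(
        text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g, projection=projection
    )

    # A single forward pass returns both conditioning tensors SDXL expects:
    # the concatenated token embeddings (1, 77, 2048) and the pooled embedding (1, 1280).
    text_embedding, pooled_embedding = double_text_encoder("A photo of a pizza.")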
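
For reference, the pooled branch assembled in TextEncoderWithPooling (target[-2:],
then the bias-free linear projection, then pool) reduces to the plain-PyTorch sketch
below; h, w and eot are illustrative names, not part of the API.

    import torch

    def pooled_embedding(h: torch.Tensor, w: torch.Tensor, eot: int) -> torch.Tensor:
        # h: hidden states produced by the tail of CLIPTextEncoderG, shape (1, 77, 1280)
        # w: weight of the bias-free 1280x1280 text projection
        # eot: end-of-text token position recorded through the SetContext layer
        projected = torch.nn.functional.linear(h, w)  # project every token: (1, 77, 1280)
        return projected[:, eot, :]  # keep only the end-of-text position: (1, 1280)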