implement DoubleTextEncoder for SDXL

This commit is contained in:
limiteinductive 2023-08-17 18:34:56 +02:00 committed by Benjamin Trom
parent 71ddb55a8e
commit 3565a4127f
4 changed files with 191 additions and 2 deletions

View file

@ -255,7 +255,6 @@ class CLIPTextEncoderG(CLIPTextEncoder):
num_layers=32,
num_attention_heads=20,
feedforward_dim=5120,
use_quick_gelu=True,
tokenizer=tokenizer,
device=device,
dtype=dtype,

View file

@ -0,0 +1,88 @@
from typing import cast
from torch import device as Device, dtype as DType, Tensor, cat
from refiners.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from jaxtyping import Float
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
def __init__(
self,
target: CLIPTextEncoderG,
projection: fl.Linear | None = None,
) -> None:
with self.setup_adapter(target=target):
tokenizer = target.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found."
super().__init__(
tokenizer,
fl.SetContext(
context="text_encoder_pooling", key="end_of_text_index", callback=self.set_end_of_text_index
),
target[1:-2],
fl.Parallel(
fl.Identity(),
fl.Chain(
target[-2:],
projection
or fl.Linear(
in_features=1280, out_features=1280, bias=False, device=target.device, dtype=target.dtype
),
fl.Lambda(func=self.pool),
),
),
)
def init_context(self) -> Contexts:
return {"text_encoder_pooling": {"end_of_text_index": []}}
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
return super().__call__(text)
@property
def tokenizer(self) -> CLIPTokenizer:
tokenizer = self.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found."
return tokenizer
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
end_of_text_index.append(cast(int, position))
def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
assert len(end_of_text_index) == 1, "End of text index not found."
return x[:, end_of_text_index[0], :]
class DoubleTextEncoder(fl.Chain):
def __init__(
self,
text_encoder_l: CLIPTextEncoderL | None = None,
text_encoder_g: CLIPTextEncoderG | None = None,
projection: fl.Linear | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
text_encoder_with_pooling = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
super().__init__(
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
fl.Lambda(func=self.concatenate_embeddings),
)
text_encoder_with_pooling.inject(parent=self.Parallel)
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
return super().__call__(text)
def concatenate_embeddings(
self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
) -> tuple[Tensor, Tensor]:
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
return text_embedding, pooled_text_embedding

View file

@ -7,7 +7,7 @@ from pathlib import Path
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.fluxion.utils import load_from_safetensors
import transformers # type: ignore
import transformers # type: ignore
from refiners.foundationals.clip.tokenizer import CLIPTokenizer

View file

@ -0,0 +1,102 @@
from typing import Any, Protocol, cast
from pathlib import Path
from warnings import warn
import pytest
import torch
from torch import Tensor
from refiners.fluxion.utils import manual_seed
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.sdxl_text_encoder import DoubleTextEncoder
class DiffusersSDXL(Protocol):
unet: fl.Module
text_encoder: fl.Module
text_encoder_2: fl.Module
tokenizer: fl.Module
tokenizer_2: fl.Module
vae: fl.Module
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
...
def encode_prompt(
self,
prompt: str,
prompt_2: str | None = None,
negative_prompt: str | None = None,
negative_prompt_2: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
if not r.is_dir():
warn(message=f"could not find Stability SDXL base weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
from diffusers import DiffusionPipeline # type: ignore
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
@pytest.fixture(scope="module")
def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
text_encoder_l = CLIPTextEncoderL()
text_encoder_g_with_projection = CLIPTextEncoderG()
text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
text_encdoer_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"
text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encdoer_g_path)
linear = text_encoder_g_with_projection.pop(index=-1)
assert isinstance(linear, fl.Linear)
double_text_encoder = DoubleTextEncoder(
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=linear
)
return double_text_encoder
@torch.no_grad()
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt = "A photo of a pizza."
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
)
double_embedding, pooled_embedding = double_text_encoder(prompt)
assert double_embedding.shape == torch.Size([1, 77, 2048])
assert pooled_embedding.shape == torch.Size([1, 1280])
embedding_1, embedding_2 = cast(
tuple[Tensor, Tensor], prompt_embeds.split(split_size=[768, 1280], dim=-1) # type: ignore
)
rembedding_1, rembedding_2 = cast(
tuple[Tensor, Tensor], double_embedding.split(split_size=[768, 1280], dim=-1) # type: ignore
)
assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)