mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
implement DoubleTextEncoder for SDXL
This commit is contained in:
parent
71ddb55a8e
commit
3565a4127f
|
@ -255,7 +255,6 @@ class CLIPTextEncoderG(CLIPTextEncoder):
|
|||
num_layers=32,
|
||||
num_attention_heads=20,
|
||||
feedforward_dim=5120,
|
||||
use_quick_gelu=True,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
from typing import cast
|
||||
from torch import device as Device, dtype as DType, Tensor, cat
|
||||
from refiners.adapters.adapter import Adapter
|
||||
from refiners.fluxion.context import Contexts
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
|
||||
from jaxtyping import Float
|
||||
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
|
||||
|
||||
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
||||
def __init__(
|
||||
self,
|
||||
target: CLIPTextEncoderG,
|
||||
projection: fl.Linear | None = None,
|
||||
) -> None:
|
||||
with self.setup_adapter(target=target):
|
||||
tokenizer = target.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found."
|
||||
super().__init__(
|
||||
tokenizer,
|
||||
fl.SetContext(
|
||||
context="text_encoder_pooling", key="end_of_text_index", callback=self.set_end_of_text_index
|
||||
),
|
||||
target[1:-2],
|
||||
fl.Parallel(
|
||||
fl.Identity(),
|
||||
fl.Chain(
|
||||
target[-2:],
|
||||
projection
|
||||
or fl.Linear(
|
||||
in_features=1280, out_features=1280, bias=False, device=target.device, dtype=target.dtype
|
||||
),
|
||||
fl.Lambda(func=self.pool),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"text_encoder_pooling": {"end_of_text_index": []}}
|
||||
|
||||
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
|
||||
return super().__call__(text)
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> CLIPTokenizer:
|
||||
tokenizer = self.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found."
|
||||
return tokenizer
|
||||
|
||||
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
|
||||
position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
|
||||
end_of_text_index.append(cast(int, position))
|
||||
|
||||
def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
|
||||
end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
|
||||
assert len(end_of_text_index) == 1, "End of text index not found."
|
||||
return x[:, end_of_text_index[0], :]
|
||||
|
||||
|
||||
class DoubleTextEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder_l: CLIPTextEncoderL | None = None,
|
||||
text_encoder_g: CLIPTextEncoderG | None = None,
|
||||
projection: fl.Linear | None = None,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
|
||||
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
|
||||
text_encoder_with_pooling = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
|
||||
super().__init__(
|
||||
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
|
||||
fl.Lambda(func=self.concatenate_embeddings),
|
||||
)
|
||||
text_encoder_with_pooling.inject(parent=self.Parallel)
|
||||
|
||||
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
|
||||
return super().__call__(text)
|
||||
|
||||
def concatenate_embeddings(
|
||||
self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
|
||||
text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
|
||||
return text_embedding, pooled_text_embedding
|
102
tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py
Normal file
102
tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
from typing import Any, Protocol, cast
|
||||
from pathlib import Path
|
||||
from warnings import warn
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
|
||||
from refiners.foundationals.latent_diffusion.sdxl_text_encoder import DoubleTextEncoder
|
||||
|
||||
|
||||
class DiffusersSDXL(Protocol):
|
||||
unet: fl.Module
|
||||
text_encoder: fl.Module
|
||||
text_encoder_2: fl.Module
|
||||
tokenizer: fl.Module
|
||||
tokenizer_2: fl.Module
|
||||
vae: fl.Module
|
||||
|
||||
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_2: str | None = None,
|
||||
negative_prompt: str | None = None,
|
||||
negative_prompt_2: str | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
...
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
|
||||
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
|
||||
if not r.is_dir():
|
||||
warn(message=f"could not find Stability SDXL base weights at {r}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return r
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
|
||||
from diffusers import DiffusionPipeline # type: ignore
|
||||
|
||||
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
|
||||
text_encoder_l = CLIPTextEncoderL()
|
||||
text_encoder_g_with_projection = CLIPTextEncoderG()
|
||||
text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
||||
|
||||
text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||
text_encdoer_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"
|
||||
|
||||
text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
|
||||
text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encdoer_g_path)
|
||||
|
||||
linear = text_encoder_g_with_projection.pop(index=-1)
|
||||
assert isinstance(linear, fl.Linear)
|
||||
|
||||
double_text_encoder = DoubleTextEncoder(
|
||||
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=linear
|
||||
)
|
||||
|
||||
return double_text_encoder
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
|
||||
manual_seed(seed=0)
|
||||
prompt = "A photo of a pizza."
|
||||
|
||||
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
|
||||
diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
|
||||
)
|
||||
|
||||
double_embedding, pooled_embedding = double_text_encoder(prompt)
|
||||
|
||||
assert double_embedding.shape == torch.Size([1, 77, 2048])
|
||||
assert pooled_embedding.shape == torch.Size([1, 1280])
|
||||
|
||||
embedding_1, embedding_2 = cast(
|
||||
tuple[Tensor, Tensor], prompt_embeds.split(split_size=[768, 1280], dim=-1) # type: ignore
|
||||
)
|
||||
|
||||
rembedding_1, rembedding_2 = cast(
|
||||
tuple[Tensor, Tensor], double_embedding.split(split_size=[768, 1280], dim=-1) # type: ignore
|
||||
)
|
||||
|
||||
assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
|
||||
assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3)
|
||||
assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
|
||||
|
||||
negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
|
||||
|
||||
assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
|
||||
assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
|
Loading…
Reference in a new issue