diff --git a/pyproject.toml b/pyproject.toml
index dd7cbc6..09b00fa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,7 @@ test = [
     # HQ-SAM missing dependency:
     # https://github.com/SysCV/sam-hq/pull/59
     "timm>=0.5.0",
+    "sentencepiece>=0.2.0",
 ]
 conversion = [
     "diffusers>=0.26.1",
diff --git a/scripts/conversion/convert_ella_adapter.py b/scripts/conversion/convert_ella_adapter.py
new file mode 100644
index 0000000..74aaa74
--- /dev/null
+++ b/scripts/conversion/convert_ella_adapter.py
@@ -0,0 +1,102 @@
+import argparse
+from pathlib import Path
+
+import torch
+from huggingface_hub import hf_hub_download  # type: ignore
+
+from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
+
+
+class Args(argparse.Namespace):
+    source_path: str
+    output_path: str | None
+    use_half: bool
+
+
+def convert(args: Args) -> dict[str, torch.Tensor]:
+    if Path(args.source_path).suffix != ".safetensors":
+        args.source_path = hf_hub_download(
+            repo_id=args.source_path, filename="ella-sd1.5-tsc-t5xl.safetensors", local_dir="tests/weights/ELLA-Adapter"
+        )
+    weights = load_from_safetensors(args.source_path)
+
+    for key in list(weights.keys()):
+        if "latents" in key:
+            new_key = "PerceiverResampler.Latents.ParameterInitialized.weight"
+            weights[new_key] = weights.pop(key)
+        elif "time_embedding" in key:
+            new_key = key.replace("time_embedding", "TimestepEncoder.RangeEncoder").replace("linear", "Linear")
+            weights[new_key] = weights.pop(key)
+        elif "proj_in" in key:
+            new_key = f"PerceiverResampler.Linear.{key.split('.')[-1]}"
+            weights[new_key] = weights.pop(key)
+        elif "time_aware" in key:
+            new_key = f"PerceiverResampler.Residual.Linear.{key.split('.')[-1]}"
+            weights[new_key] = weights.pop(key)
+        elif "attn.in_proj" in key:
+            layer_num = int(key.split(".")[2])
+            query_param, key_param, value_param = weights.pop(key).chunk(3, dim=0)
+            param_type = "weight" if "weight" in key else "bias"
+            for i, param in enumerate([query_param, key_param, value_param]):
+                new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Distribute.Linear_{i+1}.{param_type}"
+                weights[new_key] = param
+        elif "attn.out_proj" in key:
+            layer_num = int(key.split(".")[2])
+            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Linear.{key.split('.')[-1]}"
+            weights[new_key] = weights.pop(key)
+        elif "ln_ff" in key:
+            layer_num = int(key.split(".")[2])
+            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.AdaLayerNorm.Parallel.Chain.Linear.{key.split('.')[-1]}"
+            weights[new_key] = weights.pop(key)
+        elif "ln_1" in key or "ln_2" in key:
+            layer_num = int(key.split(".")[2])
+            n = 1 if int(key.split(".")[3].split("_")[-1]) == 2 else 2
+            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Distribute.AdaLayerNorm_{n}.Parallel.Chain.Linear.{key.split('.')[-1]}"
+            weights[new_key] = weights.pop(key)
+        elif "mlp" in key:
+            layer_num = int(key.split(".")[2])
+            n = 1 if "c_fc" in key else 2
+            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.FeedForward.Linear_{n}.{key.split('.')[-1]}"
+            weights[new_key] = weights.pop(key)
+
+    if args.use_half:
+        weights = {key: value.half() for key, value in weights.items()}
+
+    return weights
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert a pretrained ELLA adapter to the Refiners implementation")
+    parser.add_argument(
+        "--from",
+        type=str,
+        dest="source_path",
+        default="QQGYLab/ELLA",
+        help=(
+            "Path to a local .safetensors weights file. If not provided, the weights will be downloaded from the"
+            " Hugging Face Hub. Defaults to QQGYLab/ELLA."
+        ),
+    )
+
+    parser.add_argument(
+        "--to",
+        type=str,
+        dest="output_path",
+        default=None,
+        help=(
+            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
+            " be the source file name with a -refiners suffix."
+        ),
+    )
+    parser.add_argument(
+        "--half",
+        action="store_true",
+        dest="use_half",
+        default=False,
+        help="Use this flag to save the output file as half precision (default: full precision).",
+    )
+    args = parser.parse_args(namespace=Args())
+    weights = convert(args)
+    if args.output_path is None:
+        args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
+    save_to_safetensors(path=args.output_path, tensors=weights)
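Note on the fused attention weights handled above: the ELLA checkpoint stores each layer's query/key/value projection as a single `attn.in_proj` tensor, which the script splits with `chunk(3, dim=0)` and maps onto the three `Linear` layers of the Refiners `Distribute`. A minimal standalone sketch of that step follows; the source key name is hypothetical (only the chunking logic and the target key pattern come from the script above):

import torch

width = 768  # resampler width used by the SD1.5 ELLA adapter
source_key = "connector.attn_blocks.0.attn.in_proj.weight"  # hypothetical key, for illustration only
fused_in_proj = torch.randn(3 * width, width)  # Q, K and V projections stacked along dim 0

layer_num = int(source_key.split(".")[2])
query_w, key_w, value_w = fused_in_proj.chunk(3, dim=0)
for i, param in enumerate((query_w, key_w, value_w)):
    target_key = (
        f"PerceiverResampler.Transformer.TransformerLayer_{layer_num + 1}"
        f".Residual_1.PerceiverAttention.Attention.Distribute.Linear_{i + 1}.weight"
    )
    print(target_key, tuple(param.shape))  # each chunk is (768, 768)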
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index e09f477..356718e 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -353,6 +353,29 @@ def download_ip_adapter():
     download_files(urls, sdxl_models_folder)
 
 
+def download_t5xl_fp16():
+    base_folder = os.path.join(test_weights_dir, "QQGYLab", "T5XLFP16")
+    urls = [
+        "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/config.json",
+        "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/model.safetensors",
+        "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/special_tokens_map.json",
+        "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/spiece.model",
+        "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer.json",
+        "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer_config.json",
+    ]
+    download_files(urls, base_folder)
+
+
+def download_ella_adapter():
+    download_t5xl_fp16()
+    base_folder = os.path.join(test_weights_dir, "QQGYLab", "ELLA")
+    download_file(
+        "https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors",
+        base_folder,
+        expected_hash="5af7b200",
+    )
+
+
 def download_t2i_adapter():
     base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
     urls = [
@@ -689,6 +712,17 @@ def convert_ip_adapter():
     )
 
 
+def convert_ella_adapter():
+    os.makedirs("tests/weights/ELLA-Adapter", exist_ok=True)
+    run_conversion_script(
+        "convert_ella_adapter.py",
+        "tests/weights/QQGYLab/ELLA/ella-sd1.5-tsc-t5xl.safetensors",
+        "tests/weights/ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors",
+        half=True,
+        expected_hash="b8244cb6",
+    )
+
+
 def convert_t2i_adapter():
     os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
     run_conversion_script(
@@ -860,6 +894,7 @@ def download_all():
     download_unclip()
     download_ip_adapter()
     download_t2i_adapter()
+    download_ella_adapter()
     download_sam()
     download_hq_sam()
     download_dinov2()
@@ -884,6 +919,7 @@ def convert_all():
     convert_unclip()
     convert_ip_adapter()
     convert_t2i_adapter()
+    convert_ella_adapter()
    convert_sam()
     convert_hq_sam()
     convert_dinov2()
diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py
index eab4071..f7290e2 100644
--- a/src/refiners/foundationals/latent_diffusion/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/__init__.py
@@ -8,6 +8,7 @@ from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
 from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     SD1ControlnetAdapter,
+    SD1ELLAAdapter,
     SD1IPAdapter,
     SD1T2IAdapter,
     SD1UNet,
@@ -32,6 +33,7 @@ __all__ = [
     "SD1ControlnetAdapter",
     "SD1IPAdapter",
     "SD1T2IAdapter",
+    "SD1ELLAAdapter",
     "SDXLUNet",
     "DoubleTextEncoder",
     "SDXLIPAdapter",
diff --git a/src/refiners/foundationals/latent_diffusion/ella_adapter.py b/src/refiners/foundationals/latent_diffusion/ella_adapter.py
new file mode 100644
index 0000000..ed2b543
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/ella_adapter.py
@@ -0,0 +1,285 @@
+from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar
+
+import torch.nn as nn
+from torch import Tensor, cat, device as Device, dtype as DType
+
+import refiners.fluxion.layers as fl
+from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.context import Contexts
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
+from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
+
+if TYPE_CHECKING:
+    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+
+T = TypeVar("T", bound="SD1UNet | SDXLUNet")
+TELLAAdapter = TypeVar("TELLAAdapter", bound="ELLAAdapter[Any]")
+
+
+class LayerNormNoAffine(nn.LayerNorm, fl.Module):
+    def __init__(
+        self,
+        normalized_shape: int | Iterable[int],
+        eps: float = 1e-5,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(normalized_shape, eps=eps, elementwise_affine=False, device=device, dtype=dtype)  # type: ignore
+
+
+class TimestepEncoder(fl.Passthrough):
+    def __init__(
+        self,
+        time_embedding_dim: int,
+        time_channel: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            fl.UseContext("diffusion", "timestep"),
+            RangeEncoder(time_channel, time_embedding_dim, device=device, dtype=dtype),
+            fl.SetContext("ella", "timestep_embedding"),
+        )
+
+
+class SquaredReLU(fl.ReLU):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x: Tensor) -> Tensor:
+        return super().forward(x).pow(2)
+
+
+class AdaLayerNorm(fl.Chain):
+    def __init__(
+        self,
+        embedding_dim: int,
+        time_embedding_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            fl.Parallel(
+                LayerNormNoAffine(embedding_dim, eps=1e-6, device=device, dtype=dtype),
+                fl.Chain(
+                    fl.UseContext("ella", "timestep_embedding"),
+                    fl.SiLU(),
+                    fl.Linear(time_embedding_dim, embedding_dim * 2, device=device, dtype=dtype),
+                ),
+            ),
+            fl.Lambda(self._scale_shift_tensors),
+        )
+
+        self._init_parameters()
+
+    def _init_parameters(self) -> None:
+        _linear: fl.Linear = self.ensure_find(fl.Linear)
+        nn.init.zeros_(_linear.weight)
+        nn.init.zeros_(_linear.bias)
+
+    def _scale_shift_tensors(self, x: Tensor, time_embedding: Tensor) -> Tensor:
+        shift, scale = time_embedding.chunk(2, dim=-1)
+        return x * (1 + scale) + shift
+
+
+class ParameterInitialized(fl.Parameter):
+    def __init__(
+        self, *dims: int, requires_grad: bool = True, device: Device | str | None = None, dtype: DType | None = None
+    ) -> None:
+        super().__init__(*dims, requires_grad=requires_grad, device=device, dtype=dtype)
+        nn.init.normal_(self.weight, mean=0, std=dims[1] ** 0.5)
+
+
+class Latents(fl.Chain):
+    def __init__(
+        self,
+        num_latents: int,
+        width: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            ParameterInitialized(
+                num_latents,
+                width,
+                device=device,
+                dtype=dtype,
+            ),
+        )
+
+
+class PerceiverAttention(fl.Chain):
+    def __init__(
+        self,
+        width: int,
+        num_heads: int,
+        timestep_embedding_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            fl.Distribute(
+                AdaLayerNorm(width, timestep_embedding_dim, device=device, dtype=dtype),
+                AdaLayerNorm(width, timestep_embedding_dim, device=device, dtype=dtype),
+            ),
+            fl.Parallel(
+                fl.GetArg(index=1),
+                fl.Lambda(func=self.to_kv),
+                fl.Lambda(func=self.to_kv),
+            ),
+            fl.Attention(embedding_dim=width, num_heads=num_heads, device=device, dtype=dtype),
+        )
+
+    def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
+        return cat((latents, x), dim=-2)
+
+
+class OutputProjection(fl.Chain):
+    def __init__(
+        self, width: int, output_dim: int, device: Device | str | None = None, dtype: DType | None = None
+    ) -> None:
+        super().__init__(
+            fl.Linear(width, output_dim, device=device, dtype=dtype),
+            fl.LayerNorm(output_dim, device=device, dtype=dtype),
+        )
+
+
+class Transformer(fl.Chain):
+    pass
+
+
+class TransformerLayer(fl.Chain):
+    pass
+
+
+class FeedForward(fl.Chain):
+    def __init__(
+        self,
+        width: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            fl.Linear(width, width * 4, device=device, dtype=dtype),
+            SquaredReLU(),
+            fl.Linear(width * 4, width, device=device, dtype=dtype),
+        )
+
+
+class PerceiverResampler(fl.Chain):
+    def __init__(
+        self,
+        time_embedding_dim: int,
+        width: int,
+        num_layers: int,
+        num_heads: int,
+        num_latents: int,
+        output_dim: int | None,
+        input_dim: int | None,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            fl.Linear(input_dim, width, device=device, dtype=dtype) if input_dim else fl.Identity(),
+            fl.SetContext("perceiver_resampler", "x"),
+            Latents(num_latents, width, device=device, dtype=dtype),
+            fl.Residual(
+                fl.UseContext("ella", "timestep_embedding"),
+                fl.SiLU(),
+                fl.Linear(time_embedding_dim, width, device=device, dtype=dtype),
+            ),
+            Transformer(
+                TransformerLayer(
+                    fl.Residual(
+                        fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
+                        PerceiverAttention(width, num_heads, time_embedding_dim, device=device, dtype=dtype),
+                    ),
+                    fl.Residual(
+                        AdaLayerNorm(width, time_embedding_dim, device=device, dtype=dtype),
+                        FeedForward(width, device=device, dtype=dtype),
+                    ),
+                )
+                for _ in range(num_layers)
+            ),
+            OutputProjection(width, output_dim, device=device, dtype=dtype) if output_dim else fl.Identity(),
+        )
+
+    def init_context(self) -> Contexts:
+        return {"perceiver_resampler": {"x": None}}
+
+
+class ELLA(fl.Passthrough):
+    def __init__(
+        self,
+        time_channel: int,
+        timestep_embedding_dim: int,
+        width: int,
+        num_layers: int,
+        num_heads: int,
+        num_latents: int,
+        input_dim: int | None = None,
+        out_dim: int | None = None,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            TimestepEncoder(timestep_embedding_dim, time_channel, device=device, dtype=dtype),
+            fl.UseContext("adapted_cross_attention_block", "llm_text_embedding"),
+            PerceiverResampler(
+                timestep_embedding_dim,
+                width,
+                num_layers,
+                num_heads,
+                num_latents,
+                out_dim,
+                input_dim,
+                device=device,
+                dtype=dtype,
+            ),
+            fl.SetContext("ella", "latents"),
+        )
+
+
+class ELLACrossAttentionAdapter(fl.Chain, Adapter[fl.UseContext]):
+    # TODO: concatenate the latents with the clip text embedding https://github.com/TencentQQGYLab/ELLA/tree/main?tab=readme-ov-file#3-ellaclip-for-community-models
+    def __init__(self, target: fl.UseContext) -> None:
+        with self.setup_adapter(target):
+            super().__init__(fl.UseContext("ella", "latents"))
+
+
+class ELLAAdapter(Generic[T], fl.Chain, Adapter[T]):
+    def __init__(self, target: T, latents_encoder: ELLA, weights: dict[str, Tensor] | None = None) -> None:
+        if weights is not None:
+            latents_encoder.load_state_dict(weights)
+
+        self._latents_encoder = [latents_encoder]
+        with self.setup_adapter(target):
+            super().__init__(target)
+        self.sub_adapters = [
+            ELLACrossAttentionAdapter(use_context)
+            for cross_attn in target.layers(CrossAttentionBlock)
+            for use_context in cross_attn.layers(fl.UseContext)
+        ]
+
+    def inject(self: TELLAAdapter, parent: fl.Chain | None = None) -> TELLAAdapter:
+        for adapter in self.sub_adapters:
+            adapter.inject()
+        self.target.insert(0, self.latents_encoder)
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        for adapter in self.sub_adapters:
+            adapter.eject()
+        self.target.pop(0)
+        super().eject()
+
+    @property
+    def latents_encoder(self) -> ELLA:
+        return self._latents_encoder[0]
+
+    def set_llm_text_embedding(self, text_embedding: Tensor) -> None:
+        self.set_context("adapted_cross_attention_block", {"llm_text_embedding": text_embedding})
+
+    def init_context(self) -> Contexts:
+        return {"ella": {"timestep_embedding": None, "latents": None}}
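Note: the AdaLayerNorm defined in ella_adapter.py above is adaLN-style timestep conditioning: a non-affine LayerNorm whose scale and shift are predicted from the timestep embedding by a zero-initialized Linear, so the block behaves as a plain LayerNorm at initialization. A minimal pure-PyTorch sketch of the same computation (shapes are illustrative, taken from the SD1.5 configuration introduced below):

import torch
import torch.nn as nn

embedding_dim, time_embedding_dim = 768, 768
norm = nn.LayerNorm(embedding_dim, eps=1e-6, elementwise_affine=False)
to_scale_shift = nn.Linear(time_embedding_dim, embedding_dim * 2)
nn.init.zeros_(to_scale_shift.weight)  # zero init, as in AdaLayerNorm._init_parameters
nn.init.zeros_(to_scale_shift.bias)

x = torch.randn(1, 64, embedding_dim)       # latent tokens
t = torch.randn(1, 1, time_embedding_dim)   # timestep embedding
shift, scale = to_scale_shift(nn.functional.silu(t)).chunk(2, dim=-1)
out = norm(x) * (1 + scale) + shift
# with zero-initialized weights the block starts out as a plain LayerNorm:
assert torch.allclose(out, norm(x))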
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
index 2914ff3..d85be1c 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
@@ -1,4 +1,5 @@
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.ella_adapter import SD1ELLAAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
@@ -18,4 +19,5 @@ __all__ = [
     "SD1IPAdapter",
     "SD1T2IAdapter",
     "ICLight",
+    "SD1ELLAAdapter",
 ]
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py
new file mode 100644
index 0000000..4b4fccd
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py
@@ -0,0 +1,20 @@
+from torch import Tensor
+
+from refiners.foundationals.latent_diffusion.ella_adapter import ELLA, ELLAAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+
+
+class SD1ELLAAdapter(ELLAAdapter[SD1UNet]):
+    def __init__(self, target: SD1UNet, weights: dict[str, Tensor] | None = None) -> None:
+        latents_encoder = ELLA(
+            time_channel=320,
+            timestep_embedding_dim=768,
+            width=768,
+            num_layers=6,
+            num_heads=8,
+            num_latents=64,
+            input_dim=2048,
+            device=target.device,
+            dtype=target.dtype,
+        )
+        super().__init__(target=target, latents_encoder=latents_encoder, weights=weights)
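For context, a hedged usage sketch of the new adapter; it mirrors test_diffusion_ella_adapter in tests/e2e/test_diffusion.py below. The weight path, device and the stand-in T5 embedding are assumptions (in the test the embedding comes from the T5TextEmbedder helper):

import torch

from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.latent_diffusion import SD1ELLAAdapter, StableDiffusion_1

with no_grad():
    sd15 = StableDiffusion_1(device="cuda", dtype=torch.float16)
    # ... load UNet / CLIP text encoder / autoencoder weights as usual (omitted) ...
    adapter = SD1ELLAAdapter(
        target=sd15.unet,
        weights=load_from_safetensors("tests/weights/ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors"),
    )
    adapter.inject()

    clip_text_embedding = sd15.compute_clip_text_embedding(text="a falcon", negative_text="")
    # stand-in for the T5-XL embedding: (negative + positive) x 128 tokens x 2048 features
    llm_text_embedding = torch.randn(2, 128, 2048, device="cuda", dtype=torch.float16)

    sd15.set_inference_steps(50)
    manual_seed(1001)
    x = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
    for step in sd15.steps:
        adapter.set_llm_text_embedding(llm_text_embedding)
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=12)
    image = sd15.lda.latents_to_image(x)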
diff --git a/tests/adapters/test_ella_adapter.py b/tests/adapters/test_ella_adapter.py
new file mode 100644
index 0000000..db6dc1b
--- /dev/null
+++ b/tests/adapters/test_ella_adapter.py
@@ -0,0 +1,41 @@
+import torch
+
+import refiners.fluxion.layers as fl
+from refiners.fluxion.utils import no_grad
+from refiners.foundationals.latent_diffusion import SD1ELLAAdapter, SD1UNet
+from refiners.foundationals.latent_diffusion.ella_adapter import ELLACrossAttentionAdapter
+
+
+def new_adapter(target: SD1UNet) -> SD1ELLAAdapter:
+    return SD1ELLAAdapter(target=target)
+
+
+@no_grad()
+def test_inject_eject(test_device: torch.device):
+    unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
+    initial_repr = repr(unet)
+    adapter = new_adapter(unet)
+    assert repr(unet) == initial_repr
+    adapter.inject()
+    assert repr(unet) != initial_repr
+    adapter.eject()
+    assert repr(unet) == initial_repr
+    adapter.inject()
+    assert repr(unet) != initial_repr
+    adapter.eject()
+    assert repr(unet) == initial_repr
+
+
+@no_grad()
+def test_ella_cross_attention(test_device: torch.device):
+    unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
+    adapter = new_adapter(unet).inject()
+
+    def predicate(m: fl.Module, p: fl.Chain) -> bool:
+        return isinstance(p, ELLACrossAttentionAdapter) and isinstance(m, fl.UseContext)
+
+    for m, _ in unet.walk(predicate):
+        assert isinstance(m, fl.UseContext)
+        assert m.context == "ella"
+        assert m.key == "latents"
+    assert len(adapter.sub_adapters) == 32
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index a8f5396..c8320dc 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -7,7 +7,7 @@ from warnings import warn
 import pytest
 import torch
 from PIL import Image
-from tests.utils import ensure_similar_images
+from tests.utils import T5TextEmbedder, ensure_similar_images
 
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
@@ -16,6 +16,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion import (
     ControlLoraAdapter,
     SD1ControlnetAdapter,
+    SD1ELLAAdapter,
     SD1IPAdapter,
     SD1T2IAdapter,
     SD1UNet,
@@ -116,6 +117,11 @@ def expected_image_std_init_image(ref_path: Path) -> Image.Image:
     return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_ella_adapter(ref_path: Path) -> Image.Image:
+    return _img_open(ref_path / "expected_image_ella_adapter.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
     return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
@@ -506,6 +512,29 @@ def lda_ft_mse_weights(test_weights_path: Path) -> Path:
     return lda_weights
 
 
+@pytest.fixture(scope="module")
+def ella_weights(test_weights_path: Path) -> tuple[Path, Path]:
+    ella_adapter_weights = test_weights_path / "ELLA-Adapter" / "ella-sd1.5-tsc-t5xl.safetensors"
+    if not ella_adapter_weights.is_file():
+        warn(f"could not find weights at {ella_adapter_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    t5xl_weights = test_weights_path / "QQGYLab" / "T5XLFP16"
+    t5xl_files = [
+        "config.json",
+        "model.safetensors",
+        "special_tokens_map.json",
+        "spiece.model",
+        "tokenizer_config.json",
+        "tokenizer.json",
+    ]
+    for file in t5xl_files:
+        if not (t5xl_weights / file).is_file():
+            warn(f"could not find weights at {t5xl_weights / file}, skipping")
+            pytest.skip(allow_module_level=True)
+
+    return (ella_adapter_weights, t5xl_weights)
+
+
 @pytest.fixture(scope="module")
 def ip_adapter_weights(test_weights_path: Path) -> Path:
     ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
@@ -1799,6 +1828,43 @@ def test_diffusion_textual_inversion_random_init(
     ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
 
 
+@no_grad()
+def test_diffusion_ella_adapter(
+    sd15_std_float16: StableDiffusion_1,
+    ella_weights: tuple[Path, Path],
+    expected_image_ella_adapter: Image.Image,
+    test_device: torch.device,
+):
+    sd15 = sd15_std_float16
+    ella_adapter_weights, t5xl_weights = ella_weights
+    t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16)
+
+    prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
+    negative_prompt = ""
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    assert clip_text_embedding.dtype == torch.float16
+
+    llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt)
+    prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)
+
+    adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_adapter_weights))
+    adapter.inject()
+    sd15.set_inference_steps(50)
+    manual_seed(1001)
+    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
+
+    for step in sd15.steps:
+        adapter.set_llm_text_embedding(prompt_embedding)
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=12,
+        )
+    predicted_image = sd15.lda.latents_to_image(x)
+    ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98)
+
+
 @no_grad()
 def test_diffusion_ip_adapter(
     sd15_ddim_lda_ft_mse: StableDiffusion_1,
diff --git a/tests/e2e/test_diffusion_ref/expected_image_ella_adapter.png b/tests/e2e/test_diffusion_ref/expected_image_ella_adapter.png
new file mode 100644
index 0000000..646089c
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_image_ella_adapter.png differ
diff --git a/tests/utils.py b/tests/utils.py
index a89a86f..df0f2a2 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,7 +1,11 @@
+from pathlib import Path
+
 import numpy as np
 import piq  # type: ignore
 import torch
+import torch.nn as nn
 from PIL import Image
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 
 
 def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
@@ -16,3 +20,46 @@ def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int
     assert (psnr >= min_psnr) and (
         ssim >= min_ssim
     ), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
+
+
+class T5TextEmbedder(nn.Module):
+    def __init__(
+        self, pretrained_path: Path = Path("tests/weights/QQGYLab/T5XLFP16"), max_length: int | None = None
+    ) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+        self.model: nn.Module = T5EncoderModel.from_pretrained(pretrained_path, local_files_only=True)  # type: ignore
+        self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(pretrained_path, local_files_only=True)  # type: ignore
+        self.max_length = max_length
+
+    def forward(
+        self,
+        caption: str,
+        text_input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        max_length: int | None = None,
+    ) -> torch.Tensor:
+        if max_length is None:
+            max_length = self.max_length
+
+        if text_input_ids is None or attention_mask is None:
+            if max_length is not None:
+                text_inputs = self.tokenizer(  # type: ignore
+                    caption,
+                    return_tensors="pt",
+                    add_special_tokens=True,
+                    max_length=max_length,
+                    padding="max_length",
+                    truncation=True,
+                )
+            else:
+                text_inputs = self.tokenizer(caption, return_tensors="pt", add_special_tokens=True)  # type: ignore
+            _text_input_ids: torch.Tensor = text_inputs.input_ids.to(self.model.device)  # type: ignore
+            _attention_mask: torch.Tensor = text_inputs.attention_mask.to(self.model.device)  # type: ignore
+        else:
+            _text_input_ids: torch.Tensor = text_input_ids.to(self.model.device)  # type: ignore
+            _attention_mask: torch.Tensor = attention_mask.to(self.model.device)  # type: ignore
+
+        outputs = self.model(_text_input_ids, attention_mask=_attention_mask)
+
+        embeddings = outputs.last_hidden_state
+        return embeddings
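A hedged sketch of how the T5TextEmbedder helper above is used by the e2e test: the negative and positive T5-XL embeddings are concatenated on the batch axis and fed to SD1ELLAAdapter.set_llm_text_embedding at every step. It assumes the T5-XL weights were downloaded by prepare_test_weights.py to the helper's default path; the device is an assumption.

import torch

from tests.utils import T5TextEmbedder

device, dtype = torch.device("cuda"), torch.float16
t5_encoder = T5TextEmbedder(max_length=128).to(device, dtype)  # default path: tests/weights/QQGYLab/T5XLFP16

prompt = "a falcon perched on a gloved hand"
negative_prompt = ""
llm_text_embedding = torch.cat((t5_encoder(negative_prompt), t5_encoder(prompt))).to(device, dtype)
# with flan-t5-xl and max_length=128, llm_text_embedding has shape (2, 128, 2048),
# matching input_dim=2048 of the SD1ELLAAdapter latents encoder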