ella adapter implementation. tested with sd1.5 model

This commit is contained in:
ily-R 2024-06-30 18:56:32 +00:00 committed by Cédric Deltheil
parent a8efe5eab0
commit 277b0fd837
11 changed files with 603 additions and 1 deletions

View file

@ -55,6 +55,7 @@ test = [
# HQ-SAM missing dependency:
# https://github.com/SysCV/sam-hq/pull/59
"timm>=0.5.0",
"sentencepiece>=0.2.0",
]
conversion = [
"diffusers>=0.26.1",

View file

@ -0,0 +1,102 @@
import argparse
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download # type: ignore
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
class Args(argparse.Namespace):
source_path: str
output_path: str | None
use_half: bool
def convert(args: Args) -> dict[str, torch.Tensor]:
if Path(args.source_path).suffix != ".safetensors":
args.source_path = hf_hub_download(
repo_id=args.source_path, filename="ella-sd1.5-tsc-t5xl.safetensors", local_dir="tests/weights/ELLA-Adapter"
)
weights = load_from_safetensors(args.source_path)
for key in list(weights.keys()):
if "latents" in key:
new_key = "PerceiverResampler.Latents.ParameterInitialized.weight"
weights[new_key] = weights.pop(key)
elif "time_embedding" in key:
new_key = key.replace("time_embedding", "TimestepEncoder.RangeEncoder").replace("linear", "Linear")
weights[new_key] = weights.pop(key)
elif "proj_in" in key:
new_key = f"PerceiverResampler.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "time_aware" in key:
new_key = f"PerceiverResampler.Residual.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "attn.in_proj" in key:
layer_num = int(key.split(".")[2])
query_param, key_param, value_param = weights.pop(key).chunk(3, dim=0)
param_type = "weight" if "weight" in key else "bias"
for i, param in enumerate([query_param, key_param, value_param]):
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Distribute.Linear_{i+1}.{param_type}"
weights[new_key] = param
elif "attn.out_proj" in key:
layer_num = int(key.split(".")[2])
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "ln_ff" in key:
layer_num = int(key.split(".")[2])
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.AdaLayerNorm.Parallel.Chain.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "ln_1" in key or "ln_2" in key:
layer_num = int(key.split(".")[2])
n = 1 if int(key.split(".")[3].split("_")[-1]) == 2 else 2
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Distribute.AdaLayerNorm_{n}.Parallel.Chain.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "mlp" in key:
layer_num = int(key.split(".")[2])
n = 1 if "c_fc" in key else 2
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.FeedForward.Linear_{n}.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
if args.use_half:
weights = {key: value.half() for key, value in weights.items()}
return weights
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert a pretrained Ella Adapter to refiners implementation")
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="QQGYLab/ELLA",
help=(
"A path to a local .safetensors weights. If not provided, a repo from Hugging Face Hub will be used"
"Default to QQGYLab/ELLA"
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
" be the source path with the prefix set to refiners"
),
)
parser.add_argument(
"--half",
action="store_true",
dest="use_half",
default=True,
help="Use this flag to save the output file as half precision (default: full precision).",
)
args = parser.parse_args(namespace=Args())
weights = convert(args)
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
save_to_safetensors(path=args.output_path, tensors=weights)

View file

@ -353,6 +353,29 @@ def download_ip_adapter():
download_files(urls, sdxl_models_folder)
def download_t5xl_fp16():
base_folder = os.path.join(test_weights_dir, "QQGYLab", "T5XLFP16")
urls = [
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/config.json",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/model.safetensors",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/special_tokens_map.json",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/spiece.model",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer.json",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer_config.json",
]
download_files(urls, base_folder)
def download_ella_adapter():
download_t5xl_fp16()
base_folder = os.path.join(test_weights_dir, "QQGYLab", "ELLA")
download_file(
"https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors",
base_folder,
expected_hash="5af7b200",
)
def download_t2i_adapter():
base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
urls = [
@ -689,6 +712,17 @@ def convert_ip_adapter():
)
def convert_ella_adapter():
os.makedirs("tests/weights/ELLA-Adapter", exist_ok=True)
run_conversion_script(
"convert_ella_adapter.py",
"tests/weights/QQGYLab/ELLA/ella-sd1.5-tsc-t5xl.safetensors",
"tests/weights/ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors",
half=True,
expected_hash="b8244cb6",
)
def convert_t2i_adapter():
os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
run_conversion_script(
@ -860,6 +894,7 @@ def download_all():
download_unclip()
download_ip_adapter()
download_t2i_adapter()
download_ella_adapter()
download_sam()
download_hq_sam()
download_dinov2()
@ -884,6 +919,7 @@ def convert_all():
convert_unclip()
convert_ip_adapter()
convert_t2i_adapter()
convert_ella_adapter()
convert_sam()
convert_hq_sam()
convert_dinov2()

View file

@ -8,6 +8,7 @@ from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
SD1ControlnetAdapter,
SD1ELLAAdapter,
SD1IPAdapter,
SD1T2IAdapter,
SD1UNet,
@ -32,6 +33,7 @@ __all__ = [
"SD1ControlnetAdapter",
"SD1IPAdapter",
"SD1T2IAdapter",
"SD1ELLAAdapter",
"SDXLUNet",
"DoubleTextEncoder",
"SDXLIPAdapter",

View file

@ -0,0 +1,285 @@
from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar
import torch.nn as nn
from torch import Tensor, cat, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
if TYPE_CHECKING:
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TELLAAdapter = TypeVar("TELLAAdapter", bound="ELLAAdapter[Any]")
class LayerNormNoAffine(nn.LayerNorm, fl.Module):
def __init__(
self,
normalized_shape: int | Iterable[int],
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(normalized_shape, eps=eps, elementwise_affine=False, device=device, dtype=dtype) # type: ignore
class TimestepEncoder(fl.Passthrough):
def __init__(
self,
time_embedding_dim: int,
time_channel: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.UseContext("diffusion", "timestep"),
RangeEncoder(time_channel, time_embedding_dim, device=device, dtype=dtype),
fl.SetContext("ella", "timestep_embedding"),
)
class SquaredReLU(fl.ReLU):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return super().forward(x).pow(2)
class AdaLayerNorm(fl.Chain):
def __init__(
self,
embedding_dim: int,
time_embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Parallel(
LayerNormNoAffine(embedding_dim, eps=1e-6, device=device, dtype=dtype),
fl.Chain(
fl.UseContext("ella", "timestep_embedding"),
fl.SiLU(),
fl.Linear(time_embedding_dim, embedding_dim * 2, device=device, dtype=dtype),
),
),
fl.Lambda(self._scale_shift_tensors),
)
self._init_parameters()
def _init_parameters(self) -> None:
_linear: fl.Linear = self.ensure_find(fl.Linear)
nn.init.zeros_(_linear.weight)
nn.init.zeros_(_linear.bias)
def _scale_shift_tensors(self, x: Tensor, time_embedding: Tensor) -> Tensor:
shift, scale = time_embedding.chunk(2, dim=-1)
return x * (1 + scale) + shift
class ParameterInitialized(fl.Parameter):
def __init__(
self, *dims: int, requires_grad: bool = True, device: Device | str | None = None, dtype: DType | None = None
) -> None:
super().__init__(*dims, requires_grad=requires_grad, device=device, dtype=dtype)
nn.init.normal_(self.weight, mean=0, std=dims[1] ** 0.5)
class Latents(fl.Chain):
def __init__(
self,
num_latents: int,
width: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
ParameterInitialized(
num_latents,
width,
device=device,
dtype=dtype,
),
)
class PerceiverAttention(fl.Chain):
def __init__(
self,
width: int,
num_heads: int,
timestep_embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Distribute(
AdaLayerNorm(width, timestep_embedding_dim, device=device, dtype=dtype),
AdaLayerNorm(width, timestep_embedding_dim, device=device, dtype=dtype),
),
fl.Parallel(
fl.GetArg(index=1),
fl.Lambda(func=self.to_kv),
fl.Lambda(func=self.to_kv),
),
fl.Attention(embedding_dim=width, num_heads=num_heads, device=device, dtype=dtype),
)
def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
return cat((latents, x), dim=-2)
class OutputProjection(fl.Chain):
def __init__(
self, width: int, output_dim: int, device: Device | str | None = None, dtype: DType | None = None
) -> None:
super().__init__(
fl.Linear(width, output_dim, device=device, dtype=dtype),
fl.LayerNorm(output_dim, device=device, dtype=dtype),
)
class Transformer(fl.Chain):
pass
class TransformerLayer(fl.Chain):
pass
class FeedForward(fl.Chain):
def __init__(
self,
width: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Linear(width, width * 4, device=device, dtype=dtype),
SquaredReLU(),
fl.Linear(width * 4, width, device=device, dtype=dtype),
)
class PerceiverResampler(fl.Chain):
def __init__(
self,
time_embedding_dim: int,
width: int,
num_layers: int,
num_heads: int,
num_latents: int,
output_dim: int | None,
input_dim: int | None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Linear(input_dim, width, device=device, dtype=dtype) if input_dim else fl.Identity(),
fl.SetContext("perceiver_resampler", "x"),
Latents(num_latents, width, device=device, dtype=dtype),
fl.Residual(
fl.UseContext("ella", "timestep_embedding"),
fl.SiLU(),
fl.Linear(time_embedding_dim, width, device=device, dtype=dtype),
),
Transformer(
TransformerLayer(
fl.Residual(
fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
PerceiverAttention(width, num_heads, time_embedding_dim, device=device, dtype=dtype),
),
fl.Residual(
AdaLayerNorm(width, time_embedding_dim, device=device, dtype=dtype),
FeedForward(width, device=device, dtype=dtype),
),
)
for _ in range(num_layers)
),
OutputProjection(width, output_dim, device=device, dtype=dtype) if output_dim else fl.Identity(),
)
def init_context(self) -> Contexts:
return {"perceiver_resampler": {"x": None}}
class ELLA(fl.Passthrough):
def __init__(
self,
time_channel: int,
timestep_embedding_dim: int,
width: int,
num_layers: int,
num_heads: int,
num_latents: int,
input_dim: int | None = None,
out_dim: int | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
TimestepEncoder(timestep_embedding_dim, time_channel, device=device, dtype=dtype),
fl.UseContext("adapted_cross_attention_block", "llm_text_embedding"),
PerceiverResampler(
timestep_embedding_dim,
width,
num_layers,
num_heads,
num_latents,
out_dim,
input_dim,
device=device,
dtype=dtype,
),
fl.SetContext("ella", "latents"),
)
class ELLACrossAttentionAdapter(fl.Chain, Adapter[fl.UseContext]):
# TODO: concatenate the latents with the clip text embedding https://github.com/TencentQQGYLab/ELLA/tree/main?tab=readme-ov-file#3-ellaclip-for-community-models
def __init__(self, target: fl.UseContext) -> None:
with self.setup_adapter(target):
super().__init__(fl.UseContext("ella", "latents"))
class ELLAAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(self, target: T, latents_encoder: ELLA, weights: dict[str, Tensor] | None = None) -> None:
if weights is not None:
latents_encoder.load_state_dict(weights)
self._latents_encoder = [latents_encoder]
with self.setup_adapter(target):
super().__init__(target)
self.sub_adapters = [
ELLACrossAttentionAdapter(use_context)
for cross_attn in target.layers(CrossAttentionBlock)
for use_context in cross_attn.layers(fl.UseContext)
]
def inject(self: TELLAAdapter, parent: fl.Chain | None = None) -> TELLAAdapter:
for adapter in self.sub_adapters:
adapter.inject()
self.target.insert(0, self.latents_encoder)
return super().inject(parent)
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
self.target.pop(0)
super().eject()
@property
def latents_encoder(self) -> ELLA:
return self._latents_encoder[0]
def set_llm_text_embedding(self, text_embedding: Tensor) -> None:
self.set_context("adapted_cross_attention_block", {"llm_text_embedding": text_embedding})
def init_context(self) -> Contexts:
return {"ella": {"timestep_embedding": None, "latents": None}}

View file

@ -1,4 +1,5 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ella_adapter import SD1ELLAAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
@ -18,4 +19,5 @@ __all__ = [
"SD1IPAdapter",
"SD1T2IAdapter",
"ICLight",
"SD1ELLAAdapter",
]

View file

@ -0,0 +1,20 @@
from torch import Tensor
from refiners.foundationals.latent_diffusion.ella_adapter import ELLA, ELLAAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
class SD1ELLAAdapter(ELLAAdapter[SD1UNet]):
def __init__(self, target: SD1UNet, weights: dict[str, Tensor] | None = None) -> None:
latents_encoder = ELLA(
time_channel=320,
timestep_embedding_dim=768,
width=768,
num_layers=6,
num_heads=8,
num_latents=64,
input_dim=2048,
device=target.device,
dtype=target.dtype,
)
super().__init__(target=target, latents_encoder=latents_encoder, weights=weights)

View file

@ -0,0 +1,41 @@
import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1ELLAAdapter, SD1UNet
from refiners.foundationals.latent_diffusion.ella_adapter import ELLACrossAttentionAdapter
def new_adapter(target: SD1UNet) -> SD1ELLAAdapter:
return SD1ELLAAdapter(target=target)
@no_grad()
def test_inject_eject(test_device: torch.device):
unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
initial_repr = repr(unet)
adapter = new_adapter(unet)
assert repr(unet) == initial_repr
adapter.inject()
assert repr(unet) != initial_repr
adapter.eject()
assert repr(unet) == initial_repr
adapter.inject()
assert repr(unet) != initial_repr
adapter.eject()
assert repr(unet) == initial_repr
@no_grad()
def test_ella_cross_attention(test_device: torch.device):
unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
adapter = new_adapter(unet).inject()
def predicate(m: fl.Module, p: fl.Chain) -> bool:
return isinstance(p, ELLACrossAttentionAdapter) and isinstance(m, fl.UseContext)
for m, _ in unet.walk(predicate):
assert isinstance(m, fl.UseContext)
assert m.context == "ella"
assert m.key == "latents"
assert len(adapter.sub_adapters) == 32

View file

@ -7,7 +7,7 @@ from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from tests.utils import T5TextEmbedder, ensure_similar_images
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
@ -16,6 +16,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import (
ControlLoraAdapter,
SD1ControlnetAdapter,
SD1ELLAAdapter,
SD1IPAdapter,
SD1T2IAdapter,
SD1UNet,
@ -116,6 +117,11 @@ def expected_image_std_init_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
@pytest.fixture
def expected_image_ella_adapter(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ella_adapter.png").convert("RGB")
@pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
@ -506,6 +512,29 @@ def lda_ft_mse_weights(test_weights_path: Path) -> Path:
return lda_weights
@pytest.fixture(scope="module")
def ella_weights(test_weights_path: Path) -> tuple[Path, Path]:
ella_adapter_weights = test_weights_path / "ELLA-Adapter" / "ella-sd1.5-tsc-t5xl.safetensors"
if not ella_adapter_weights.is_file():
warn(f"could not find weights at {ella_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
t5xl_weights = test_weights_path / "QQGYLab" / "T5XLFP16"
t5xl_files = [
"config.json",
"model.safetensors",
"special_tokens_map.json",
"spiece.model",
"tokenizer_config.json",
"tokenizer.json",
]
for file in t5xl_files:
if not (t5xl_weights / file).is_file():
warn(f"could not find weights at {t5xl_weights / file}, skipping")
pytest.skip(allow_module_level=True)
return (ella_adapter_weights, t5xl_weights)
@pytest.fixture(scope="module")
def ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
@ -1799,6 +1828,43 @@ def test_diffusion_textual_inversion_random_init(
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_ella_adapter(
sd15_std_float16: StableDiffusion_1,
ella_weights: tuple[Path, Path],
expected_image_ella_adapter: Image.Image,
test_device: torch.device,
):
sd15 = sd15_std_float16
ella_adapter_weights, t5xl_weights = ella_weights
t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16)
prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
negative_prompt = ""
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
assert clip_text_embedding.dtype == torch.float16
llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt)
prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)
adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_adapter_weights))
adapter.inject()
sd15.set_inference_steps(50)
manual_seed(1001)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
adapter.set_llm_text_embedding(prompt_embedding)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=12,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_ip_adapter(
sd15_ddim_lda_ft_mse: StableDiffusion_1,

Binary file not shown.

After

Width:  |  Height:  |  Size: 414 KiB

View file

@ -1,7 +1,11 @@
from pathlib import Path
import numpy as np
import piq # type: ignore
import torch
import torch.nn as nn
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
@ -16,3 +20,46 @@ def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int
assert (psnr >= min_psnr) and (
ssim >= min_ssim
), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
class T5TextEmbedder(nn.Module):
def __init__(
self, pretrained_path: Path = Path("tests/weights/QQGYLab/T5XLFP16"), max_length: int | None = None
) -> None:
super().__init__() # type: ignore[reportUnknownMemberType]
self.model: nn.Module = T5EncoderModel.from_pretrained(pretrained_path, local_files_only=True) # type: ignore
self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained(pretrained_path, local_files_only=True) # type: ignore
self.max_length = max_length
def forward(
self,
caption: str,
text_input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
max_length: int | None = None,
) -> torch.Tensor:
if max_length is None:
max_length = self.max_length
if text_input_ids is None or attention_mask is None:
if max_length is not None:
text_inputs = self.tokenizer( # type: ignore
caption,
return_tensors="pt",
add_special_tokens=True,
max_length=max_length,
padding="max_length",
truncation=True,
)
else:
text_inputs = self.tokenizer(caption, return_tensors="pt", add_special_tokens=True) # type: ignore
_text_input_ids: torch.Tensor = text_inputs.input_ids.to(self.model.device) # type: ignore
_attention_mask: torch.Tensor = text_inputs.attention_mask.to(self.model.device) # type: ignore
else:
_text_input_ids: torch.Tensor = text_input_ids.to(self.model.device) # type: ignore
_attention_mask: torch.Tensor = attention_mask.to(self.model.device) # type: ignore
outputs = self.model(_text_input_ids, attention_mask=_attention_mask)
embeddings = outputs.last_hidden_state
return embeddings