mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
ella adapter implementation. tested with sd1.5 model
This commit is contained in:
parent
a8efe5eab0
commit
277b0fd837
|
@ -55,6 +55,7 @@ test = [
|
||||||
# HQ-SAM missing dependency:
|
# HQ-SAM missing dependency:
|
||||||
# https://github.com/SysCV/sam-hq/pull/59
|
# https://github.com/SysCV/sam-hq/pull/59
|
||||||
"timm>=0.5.0",
|
"timm>=0.5.0",
|
||||||
|
"sentencepiece>=0.2.0",
|
||||||
]
|
]
|
||||||
conversion = [
|
conversion = [
|
||||||
"diffusers>=0.26.1",
|
"diffusers>=0.26.1",
|
||||||
|
|
102
scripts/conversion/convert_ella_adapter.py
Normal file
102
scripts/conversion/convert_ella_adapter.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download # type: ignore
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
    """Typed namespace holding the command-line options of this script."""

    # Local `.safetensors` file, or a Hugging Face Hub repo id (`convert`
    # treats any value without a `.safetensors` suffix as a repo id).
    source_path: str
    # Destination file; when None the entry point derives a name from the source.
    output_path: str | None
    # When true, `convert` casts every tensor with `.half()`.
    use_half: bool
|
||||||
|
|
||||||
|
|
||||||
|
def convert(args: Args) -> dict[str, torch.Tensor]:
    """Load an ELLA checkpoint and rename its tensors to refiners layer paths.

    If `args.source_path` does not end in `.safetensors`, it is used as a
    Hugging Face Hub repo id: `ella-sd1.5-tsc-t5xl.safetensors` is downloaded
    and `args.source_path` is overwritten with the returned local path.

    Keys matching none of the patterns below keep their original name. Fused
    `attn.in_proj` tensors are split in three along dim 0 and stored as
    `Linear_1` to `Linear_3`. When `args.use_half` is set, every tensor is cast
    with `.half()` before being returned.
    """
    if Path(args.source_path).suffix != ".safetensors":
        args.source_path = hf_hub_download(
            repo_id=args.source_path, filename="ella-sd1.5-tsc-t5xl.safetensors", local_dir="tests/weights/ELLA-Adapter"
        )
    weights = load_from_safetensors(args.source_path)

    def layer_path(name: str) -> str:
        # The third dotted component is the source layer index; target layers are numbered one higher.
        return f"PerceiverResampler.Transformer.TransformerLayer_{int(name.split('.')[2]) + 1}"

    # Iterate over a snapshot of the keys: entries are popped and added along the way.
    for old_key in list(weights.keys()):
        leaf = old_key.split(".")[-1]

        if "latents" in old_key:
            new_key = "PerceiverResampler.Latents.ParameterInitialized.weight"
        elif "time_embedding" in old_key:
            new_key = old_key.replace("time_embedding", "TimestepEncoder.RangeEncoder").replace("linear", "Linear")
        elif "proj_in" in old_key:
            new_key = f"PerceiverResampler.Linear.{leaf}"
        elif "time_aware" in old_key:
            new_key = f"PerceiverResampler.Residual.Linear.{leaf}"
        elif "attn.in_proj" in old_key:
            attention = f"{layer_path(old_key)}.Residual_1.PerceiverAttention.Attention"
            query_part, key_part, value_part = weights.pop(old_key).chunk(3, dim=0)
            kind = "weight" if "weight" in old_key else "bias"
            for position, part in enumerate((query_part, key_part, value_part), start=1):
                weights[f"{attention}.Distribute.Linear_{position}.{kind}"] = part
            continue
        elif "attn.out_proj" in old_key:
            new_key = f"{layer_path(old_key)}.Residual_1.PerceiverAttention.Attention.Linear.{leaf}"
        elif "ln_ff" in old_key:
            new_key = f"{layer_path(old_key)}.Residual_2.AdaLayerNorm.Parallel.Chain.Linear.{leaf}"
        elif "ln_1" in old_key or "ln_2" in old_key:
            prefix = layer_path(old_key)
            # A trailing `_2` on the fourth component selects AdaLayerNorm_1; anything else selects AdaLayerNorm_2.
            norm_index = 1 if int(old_key.split(".")[3].split("_")[-1]) == 2 else 2
            new_key = (
                f"{prefix}.Residual_1.PerceiverAttention.Distribute.AdaLayerNorm_{norm_index}.Parallel.Chain.Linear.{leaf}"
            )
        elif "mlp" in old_key:
            prefix = layer_path(old_key)
            linear_index = 1 if "c_fc" in old_key else 2
            new_key = f"{prefix}.Residual_2.FeedForward.Linear_{linear_index}.{leaf}"
        else:
            continue

        weights[new_key] = weights.pop(old_key)

    if args.use_half:
        weights = {name: tensor.half() for name, tensor in weights.items()}

    return weights
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, convert the checkpoint and
    # save the renamed tensors as a safetensors file.
    parser = argparse.ArgumentParser(description="Convert a pretrained Ella Adapter to refiners implementation")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="QQGYLab/ELLA",
        help=(
            "A path to a local .safetensors weights. If not provided, a repo from Hugging Face Hub will be used."
            " Default to QQGYLab/ELLA"
        ),
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
            " be the source path with the prefix set to refiners"
        ),
    )
    parser.add_argument(
        "--half",
        action="store_true",
        dest="use_half",
        # With `store_true` the default must be False; a True default made the
        # flag a no-op and always produced half-precision output.
        default=False,
        help="Use this flag to save the output file as half precision (default: full precision).",
    )
    args = parser.parse_args(namespace=Args())
    weights = convert(args)
    if args.output_path is None:
        # `convert` may have replaced `source_path` with the downloaded file path,
        # so the stem is read after conversion.
        args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
    save_to_safetensors(path=args.output_path, tensors=weights)
|
|
@ -353,6 +353,29 @@ def download_ip_adapter():
|
||||||
download_files(urls, sdxl_models_folder)
|
download_files(urls, sdxl_models_folder)
|
||||||
|
|
||||||
|
|
||||||
|
def download_t5xl_fp16():
    """Request the flan-t5-xl text encoder files hosted in the QQGYLab/ELLA repo.

    The six URLs and the `QQGYLab/T5XLFP16` folder under `test_weights_dir`
    are handed to `download_files`.
    """
    remote_root = "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder"
    filenames = (
        "config.json",
        "model.safetensors",
        "special_tokens_map.json",
        "spiece.model",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    destination = os.path.join(test_weights_dir, "QQGYLab", "T5XLFP16")
    download_files([f"{remote_root}/{filename}" for filename in filenames], destination)
|
||||||
|
|
||||||
|
|
||||||
|
def download_ella_adapter():
    """Download the T5-XL encoder files, then the ELLA SD1.5 adapter checkpoint.

    The checkpoint URL, the `QQGYLab/ELLA` folder under `test_weights_dir` and
    the expected hash are passed to `download_file`.
    """
    download_t5xl_fp16()
    checkpoint_url = "https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors"
    destination = os.path.join(test_weights_dir, "QQGYLab", "ELLA")
    download_file(checkpoint_url, destination, expected_hash="5af7b200")
|
||||||
|
|
||||||
|
|
||||||
def download_t2i_adapter():
|
def download_t2i_adapter():
|
||||||
base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
|
base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
|
||||||
urls = [
|
urls = [
|
||||||
|
@ -689,6 +712,17 @@ def convert_ip_adapter():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ella_adapter():
    """Run `convert_ella_adapter.py` on the downloaded ELLA checkpoint.

    Creates `tests/weights/ELLA-Adapter` if needed, then calls
    `run_conversion_script` with the script name, the two checkpoint paths
    (presumably source then destination; confirm against that helper),
    `half=True` and the expected hash.
    """
    output_dir = "tests/weights/ELLA-Adapter"
    checkpoint_name = "ella-sd1.5-tsc-t5xl.safetensors"
    os.makedirs(output_dir, exist_ok=True)
    run_conversion_script(
        "convert_ella_adapter.py",
        f"tests/weights/QQGYLab/ELLA/{checkpoint_name}",
        f"{output_dir}/{checkpoint_name}",
        half=True,
        expected_hash="b8244cb6",
    )
|
||||||
|
|
||||||
|
|
||||||
def convert_t2i_adapter():
|
def convert_t2i_adapter():
|
||||||
os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
|
os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
|
||||||
run_conversion_script(
|
run_conversion_script(
|
||||||
|
@ -860,6 +894,7 @@ def download_all():
|
||||||
download_unclip()
|
download_unclip()
|
||||||
download_ip_adapter()
|
download_ip_adapter()
|
||||||
download_t2i_adapter()
|
download_t2i_adapter()
|
||||||
|
download_ella_adapter()
|
||||||
download_sam()
|
download_sam()
|
||||||
download_hq_sam()
|
download_hq_sam()
|
||||||
download_dinov2()
|
download_dinov2()
|
||||||
|
@ -884,6 +919,7 @@ def convert_all():
|
||||||
convert_unclip()
|
convert_unclip()
|
||||||
convert_ip_adapter()
|
convert_ip_adapter()
|
||||||
convert_t2i_adapter()
|
convert_t2i_adapter()
|
||||||
|
convert_ella_adapter()
|
||||||
convert_sam()
|
convert_sam()
|
||||||
convert_hq_sam()
|
convert_hq_sam()
|
||||||
convert_dinov2()
|
convert_dinov2()
|
||||||
|
|
|
@ -8,6 +8,7 @@ from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
|
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
||||||
SD1ControlnetAdapter,
|
SD1ControlnetAdapter,
|
||||||
|
SD1ELLAAdapter,
|
||||||
SD1IPAdapter,
|
SD1IPAdapter,
|
||||||
SD1T2IAdapter,
|
SD1T2IAdapter,
|
||||||
SD1UNet,
|
SD1UNet,
|
||||||
|
@ -32,6 +33,7 @@ __all__ = [
|
||||||
"SD1ControlnetAdapter",
|
"SD1ControlnetAdapter",
|
||||||
"SD1IPAdapter",
|
"SD1IPAdapter",
|
||||||
"SD1T2IAdapter",
|
"SD1T2IAdapter",
|
||||||
|
"SD1ELLAAdapter",
|
||||||
"SDXLUNet",
|
"SDXLUNet",
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
"SDXLIPAdapter",
|
"SDXLIPAdapter",
|
||||||
|
|
285
src/refiners/foundationals/latent_diffusion/ella_adapter.py
Normal file
285
src/refiners/foundationals/latent_diffusion/ella_adapter.py
Normal file
|
@ -0,0 +1,285 @@
|
||||||
|
from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor, cat, device as Device, dtype as DType
|
||||||
|
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
||||||
|
from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
|
||||||
|
TELLAAdapter = TypeVar("TELLAAdapter", bound="ELLAAdapter[Any]")
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNormNoAffine(nn.LayerNorm, fl.Module):
    """`nn.LayerNorm` constructed with `elementwise_affine=False`, also typed as an `fl.Module`."""

    def __init__(
        self,
        normalized_shape: int | Iterable[int],
        eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(  # type: ignore
            normalized_shape,
            elementwise_affine=False,
            eps=eps,
            dtype=dtype,
            device=device,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEncoder(fl.Passthrough):
    """Encode the diffusion timestep and publish it for the ELLA layers.

    Reads `timestep` from the `diffusion` context, runs it through a
    `RangeEncoder`, and stores the result under `timestep_embedding` in the
    `ella` context.
    """

    def __init__(
        self,
        time_embedding_dim: int,
        time_channel: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        read_timestep = fl.UseContext("diffusion", "timestep")
        encoder = RangeEncoder(time_channel, time_embedding_dim, device=device, dtype=dtype)
        publish_embedding = fl.SetContext("ella", "timestep_embedding")
        super().__init__(read_timestep, encoder, publish_embedding)
|
||||||
|
|
||||||
|
|
||||||
|
class SquaredReLU(fl.ReLU):
    """ReLU activation whose output is squared element-wise."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Return `relu(x)` raised to the power of two."""
        rectified = super().forward(x)
        return rectified.pow(2)
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(fl.Chain):
    """Layer norm without affine parameters, modulated by the timestep embedding.

    One branch normalizes the input. The other reads `timestep_embedding` from
    the `ella` context and projects it (SiLU then Linear) to
    `2 * embedding_dim` features, which are split into a shift and a scale.
    The projection is zero-initialized.
    """

    def __init__(
        self,
        embedding_dim: int,
        time_embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        normalize = LayerNormNoAffine(embedding_dim, eps=1e-6, device=device, dtype=dtype)
        modulation = fl.Chain(
            fl.UseContext("ella", "timestep_embedding"),
            fl.SiLU(),
            fl.Linear(time_embedding_dim, embedding_dim * 2, device=device, dtype=dtype),
        )
        super().__init__(
            fl.Parallel(normalize, modulation),
            fl.Lambda(self._scale_shift_tensors),
        )

        self._init_parameters()

    def _init_parameters(self) -> None:
        """Zero the weight and bias of the modulation projection."""
        projection: fl.Linear = self.ensure_find(fl.Linear)
        for parameter in (projection.weight, projection.bias):
            nn.init.zeros_(parameter)

    def _scale_shift_tensors(self, x: Tensor, time_embedding: Tensor) -> Tensor:
        """Apply `x * (1 + scale) + shift`, with shift and scale the two halves of `time_embedding`."""
        shift, scale = time_embedding.chunk(2, dim=-1)
        scaled = x * (1 + scale)
        return scaled + shift
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterInitialized(fl.Parameter):
    """`fl.Parameter` whose weight is drawn from a zero-mean normal distribution.

    The standard deviation is `dims[-1] ** -0.5`, i.e. the inverse square root
    of the last dimension. This follows the reference ELLA resampler, which
    scales its random latents by `width ** -0.5`; the previous positive
    exponent produced values about `width` times too large. For the 2-D
    latents built in this module, `dims[-1]` is the same entry the previous
    `dims[1]` selected.
    """

    def __init__(
        self, *dims: int, requires_grad: bool = True, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        super().__init__(*dims, requires_grad=requires_grad, device=device, dtype=dtype)
        nn.init.normal_(self.weight, mean=0, std=dims[-1] ** -0.5)
|
||||||
|
|
||||||
|
|
||||||
|
class Latents(fl.Chain):
    """Chain holding a single `ParameterInitialized` of shape `(num_latents, width)`."""

    def __init__(
        self,
        num_latents: int,
        width: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        latents = ParameterInitialized(num_latents, width, device=device, dtype=dtype)
        super().__init__(latents)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceiverAttention(fl.Chain):
    """Attention step of the resampler operating on a pair of inputs.

    Each of the two inputs goes through its own `AdaLayerNorm`. The second
    input is then forwarded as is, and `to_kv` is applied twice to build two
    further tensors; the three results are fed to `fl.Attention`.
    """

    def __init__(
        self,
        width: int,
        num_heads: int,
        timestep_embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        norms = fl.Distribute(
            AdaLayerNorm(width, timestep_embedding_dim, device=device, dtype=dtype),
            AdaLayerNorm(width, timestep_embedding_dim, device=device, dtype=dtype),
        )
        attention_inputs = fl.Parallel(
            fl.GetArg(index=1),
            fl.Lambda(func=self.to_kv),
            fl.Lambda(func=self.to_kv),
        )
        attention = fl.Attention(embedding_dim=width, num_heads=num_heads, device=device, dtype=dtype)
        super().__init__(norms, attention_inputs, attention)

    def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
        """Concatenate `latents` followed by `x` along the second-to-last dimension."""
        return cat((latents, x), dim=-2)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputProjection(fl.Chain):
    """Linear projection from `width` to `output_dim`, followed by a LayerNorm."""

    def __init__(
        self, width: int, output_dim: int, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        projection = fl.Linear(width, output_dim, device=device, dtype=dtype)
        norm = fl.LayerNorm(output_dim, device=device, dtype=dtype)
        super().__init__(projection, norm)
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(fl.Chain):
    """Named `fl.Chain` wrapping the stack of `TransformerLayer` blocks."""
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLayer(fl.Chain):
    """Named `fl.Chain` wrapping one attention + feed-forward layer of the resampler."""
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(fl.Chain):
    """Two-layer MLP: `width` to `4 * width`, `SquaredReLU`, then back to `width`."""

    def __init__(
        self,
        width: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        hidden_width = width * 4
        expand = fl.Linear(width, hidden_width, device=device, dtype=dtype)
        activation = SquaredReLU()
        contract = fl.Linear(hidden_width, width, device=device, dtype=dtype)
        super().__init__(expand, activation, contract)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceiverResampler(fl.Chain):
    """Timestep-aware resampler built from learned latents and transformer layers.

    Layers, in order:
      1. a Linear from `input_dim` to `width` (an Identity when `input_dim` is falsy);
      2. storage of that result under `x` in the `perceiver_resampler` context;
      3. the learned `Latents`;
      4. a residual branch projecting the `ella` timestep embedding to `width`;
      5. `num_layers` `TransformerLayer` blocks, each made of a residual
         `PerceiverAttention` over (context `x`, current input) and a residual
         `AdaLayerNorm` + `FeedForward`;
      6. an `OutputProjection` to `output_dim` (an Identity when `output_dim` is falsy).
    """

    def __init__(
        self,
        time_embedding_dim: int,
        width: int,
        num_layers: int,
        num_heads: int,
        num_latents: int,
        output_dim: int | None,
        input_dim: int | None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        def make_layer() -> TransformerLayer:
            # One resampler layer: attention over (stored input, latents), then feed-forward.
            return TransformerLayer(
                fl.Residual(
                    fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
                    PerceiverAttention(width, num_heads, time_embedding_dim, device=device, dtype=dtype),
                ),
                fl.Residual(
                    AdaLayerNorm(width, time_embedding_dim, device=device, dtype=dtype),
                    FeedForward(width, device=device, dtype=dtype),
                ),
            )

        # Modules are created in the same order as they appear in the chain.
        input_projection = fl.Linear(input_dim, width, device=device, dtype=dtype) if input_dim else fl.Identity()
        store_input = fl.SetContext("perceiver_resampler", "x")
        latents = Latents(num_latents, width, device=device, dtype=dtype)
        time_aware_shift = fl.Residual(
            fl.UseContext("ella", "timestep_embedding"),
            fl.SiLU(),
            fl.Linear(time_embedding_dim, width, device=device, dtype=dtype),
        )
        transformer = Transformer(make_layer() for _ in range(num_layers))
        output_projection = (
            OutputProjection(width, output_dim, device=device, dtype=dtype) if output_dim else fl.Identity()
        )

        super().__init__(
            input_projection,
            store_input,
            latents,
            time_aware_shift,
            transformer,
            output_projection,
        )

    def init_context(self) -> Contexts:
        """Declare the `perceiver_resampler` context with an empty `x` slot."""
        return {"perceiver_resampler": {"x": None}}
|
||||||
|
|
||||||
|
|
||||||
|
class ELLA(fl.Passthrough):
    """Latents encoder: timestep encoding followed by the perceiver resampler.

    Runs a `TimestepEncoder`, reads `llm_text_embedding` from the
    `adapted_cross_attention_block` context, passes it through a
    `PerceiverResampler`, and stores the result under `latents` in the `ella`
    context.
    """

    def __init__(
        self,
        time_channel: int,
        timestep_embedding_dim: int,
        width: int,
        num_layers: int,
        num_heads: int,
        num_latents: int,
        input_dim: int | None = None,
        out_dim: int | None = None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        timestep_encoder = TimestepEncoder(
            time_embedding_dim=timestep_embedding_dim,
            time_channel=time_channel,
            device=device,
            dtype=dtype,
        )
        resampler = PerceiverResampler(
            time_embedding_dim=timestep_embedding_dim,
            width=width,
            num_layers=num_layers,
            num_heads=num_heads,
            num_latents=num_latents,
            output_dim=out_dim,
            input_dim=input_dim,
            device=device,
            dtype=dtype,
        )
        super().__init__(
            timestep_encoder,
            fl.UseContext("adapted_cross_attention_block", "llm_text_embedding"),
            resampler,
            fl.SetContext("ella", "latents"),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ELLACrossAttentionAdapter(fl.Chain, Adapter[fl.UseContext]):
    """Adapter for a `fl.UseContext` layer that reads `latents` from the `ella` context instead."""

    # TODO: concatenate the latents with the clip text embedding https://github.com/TencentQQGYLab/ELLA/tree/main?tab=readme-ov-file#3-ellaclip-for-community-models
    def __init__(self, target: fl.UseContext) -> None:
        with self.setup_adapter(target):
            read_latents = fl.UseContext("ella", "latents")
            super().__init__(read_latents)
|
||||||
|
|
||||||
|
|
||||||
|
class ELLAAdapter(Generic[T], fl.Chain, Adapter[T]):
    """Adapter plugging an `ELLA` latents encoder into a UNet.

    One `ELLACrossAttentionAdapter` is created for every `fl.UseContext` found
    inside each `CrossAttentionBlock` of the target. On injection the latents
    encoder is inserted as the first layer of the target.
    """

    def __init__(self, target: T, latents_encoder: ELLA, weights: dict[str, Tensor] | None = None) -> None:
        """Optionally load `weights` into `latents_encoder`, then set up the adapter around `target`."""
        if weights is not None:
            latents_encoder.load_state_dict(weights)

        # Kept inside a list, presumably so the encoder is not registered as a
        # child module of this chain. TODO confirm.
        self._latents_encoder = [latents_encoder]
        with self.setup_adapter(target):
            super().__init__(target)

        sub_adapters: list[ELLACrossAttentionAdapter] = []
        for cross_attention_block in target.layers(CrossAttentionBlock):
            for use_context in cross_attention_block.layers(fl.UseContext):
                sub_adapters.append(ELLACrossAttentionAdapter(use_context))
        self.sub_adapters = sub_adapters

    def inject(self: TELLAAdapter, parent: fl.Chain | None = None) -> TELLAAdapter:
        """Inject every sub-adapter, put the latents encoder first in the target, then inject self."""
        for sub_adapter in self.sub_adapters:
            sub_adapter.inject()
        self.target.insert(0, self.latents_encoder)
        return super().inject(parent)

    def eject(self) -> None:
        """Eject every sub-adapter, remove the first layer of the target, then eject self."""
        for sub_adapter in self.sub_adapters:
            sub_adapter.eject()
        # Removes index 0, where `inject` placed the latents encoder.
        self.target.pop(0)
        super().eject()

    @property
    def latents_encoder(self) -> ELLA:
        """The `ELLA` module given at construction."""
        return self._latents_encoder[0]

    def set_llm_text_embedding(self, text_embedding: Tensor) -> None:
        """Store `text_embedding` under `llm_text_embedding` in the `adapted_cross_attention_block` context."""
        self.set_context("adapted_cross_attention_block", {"llm_text_embedding": text_embedding})

    def init_context(self) -> Contexts:
        """Declare the `ella` context with empty `timestep_embedding` and `latents` slots."""
        return {"ella": {"timestep_embedding": None, "latents": None}}
|
|
@ -1,4 +1,5 @@
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ella_adapter import SD1ELLAAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
|
@ -18,4 +19,5 @@ __all__ = [
|
||||||
"SD1IPAdapter",
|
"SD1IPAdapter",
|
||||||
"SD1T2IAdapter",
|
"SD1T2IAdapter",
|
||||||
"ICLight",
|
"ICLight",
|
||||||
|
"SD1ELLAAdapter",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.ella_adapter import ELLA, ELLAAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
|
||||||
|
|
||||||
|
class SD1ELLAAdapter(ELLAAdapter[SD1UNet]):
    """`ELLAAdapter` for `SD1UNet`, with the encoder hyper-parameters fixed below."""

    def __init__(self, target: SD1UNet, weights: dict[str, Tensor] | None = None) -> None:
        """Build the `ELLA` encoder on the device and dtype of `target`, optionally loading `weights`."""
        encoder = ELLA(
            time_channel=320,
            timestep_embedding_dim=768,
            width=768,
            num_layers=6,
            num_heads=8,
            num_latents=64,
            input_dim=2048,
            device=target.device,
            dtype=target.dtype,
        )
        super().__init__(target=target, latents_encoder=encoder, weights=weights)
|
41
tests/adapters/test_ella_adapter.py
Normal file
41
tests/adapters/test_ella_adapter.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.utils import no_grad
|
||||||
|
from refiners.foundationals.latent_diffusion import SD1ELLAAdapter, SD1UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.ella_adapter import ELLACrossAttentionAdapter
|
||||||
|
|
||||||
|
|
||||||
|
def new_adapter(target: SD1UNet) -> SD1ELLAAdapter:
    """Build an `SD1ELLAAdapter` around `target` without loading weights."""
    adapter = SD1ELLAAdapter(target=target)
    return adapter
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_inject_eject(test_device: torch.device):
    """Injecting changes the UNet repr and ejecting restores it, twice in a row."""
    unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
    initial_repr = repr(unet)
    adapter = new_adapter(unet)
    # Building the adapter alone must not alter the UNet.
    assert repr(unet) == initial_repr
    for _ in range(2):
        adapter.inject()
        assert repr(unet) != initial_repr
        adapter.eject()
        assert repr(unet) == initial_repr
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_ella_cross_attention(test_device: torch.device):
    """Every `UseContext` inside an ELLA cross-attention adapter reads `ella` / `latents`."""
    unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
    adapter = new_adapter(unet).inject()

    def is_adapted_use_context(module: fl.Module, parent: fl.Chain) -> bool:
        return isinstance(parent, ELLACrossAttentionAdapter) and isinstance(module, fl.UseContext)

    for module, _ in unet.walk(is_adapted_use_context):
        assert isinstance(module, fl.UseContext)
        assert module.context == "ella"
        assert module.key == "latents"
    assert len(adapter.sub_adapters) == 32
|
|
@ -7,7 +7,7 @@ from warnings import warn
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import T5TextEmbedder, ensure_similar_images
|
||||||
|
|
||||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||||
|
@ -16,6 +16,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.foundationals.latent_diffusion import (
|
from refiners.foundationals.latent_diffusion import (
|
||||||
ControlLoraAdapter,
|
ControlLoraAdapter,
|
||||||
SD1ControlnetAdapter,
|
SD1ControlnetAdapter,
|
||||||
|
SD1ELLAAdapter,
|
||||||
SD1IPAdapter,
|
SD1IPAdapter,
|
||||||
SD1T2IAdapter,
|
SD1T2IAdapter,
|
||||||
SD1UNet,
|
SD1UNet,
|
||||||
|
@ -116,6 +117,11 @@ def expected_image_std_init_image(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def expected_image_ella_adapter(ref_path: Path) -> Image.Image:
    """Reference image for the ELLA adapter end-to-end test, converted to RGB."""
    reference = _img_open(ref_path / "expected_image_ella_adapter.png")
    return reference.convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
|
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
||||||
|
@ -506,6 +512,29 @@ def lda_ft_mse_weights(test_weights_path: Path) -> Path:
|
||||||
return lda_weights
|
return lda_weights
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def ella_weights(test_weights_path: Path) -> tuple[Path, Path]:
    """Return `(converted ELLA adapter weights, T5-XL folder)`.

    Warns and skips as soon as the adapter file or one of the expected T5-XL
    files is missing.
    """
    ella_adapter_weights = test_weights_path / "ELLA-Adapter" / "ella-sd1.5-tsc-t5xl.safetensors"
    if not ella_adapter_weights.is_file():
        warn(f"could not find weights at {ella_adapter_weights}, skipping")
        pytest.skip(allow_module_level=True)

    t5xl_weights = test_weights_path / "QQGYLab" / "T5XLFP16"
    required_names = (
        "config.json",
        "model.safetensors",
        "special_tokens_map.json",
        "spiece.model",
        "tokenizer_config.json",
        "tokenizer.json",
    )
    for required_file in (t5xl_weights / name for name in required_names):
        if required_file.is_file():
            continue
        warn(f"could not find weights at {required_file}, skipping")
        pytest.skip(allow_module_level=True)

    return (ella_adapter_weights, t5xl_weights)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ip_adapter_weights(test_weights_path: Path) -> Path:
|
def ip_adapter_weights(test_weights_path: Path) -> Path:
|
||||||
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
|
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
|
||||||
|
@ -1799,6 +1828,43 @@ def test_diffusion_textual_inversion_random_init(
|
||||||
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_diffusion_ella_adapter(
    sd15_std_float16: StableDiffusion_1,
    ella_weights: tuple[Path, Path],
    expected_image_ella_adapter: Image.Image,
    test_device: torch.device,
):
    """Run 50 SD1.5 steps with the ELLA adapter injected and compare against the reference image."""
    sd15 = sd15_std_float16
    ella_adapter_weights, t5xl_weights = ella_weights
    t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16)

    prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
    negative_prompt = ""
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    assert clip_text_embedding.dtype == torch.float16

    # T5 embeddings: the negative prompt comes first, then the prompt.
    llm_text_embedding = t5_encoder(prompt)
    negative_prompt_embeds = t5_encoder(negative_prompt)
    prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)

    adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_adapter_weights))
    adapter.inject()
    sd15.set_inference_steps(50)
    manual_seed(1001)
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        # The LLM embedding is set again before every denoising step.
        adapter.set_llm_text_embedding(prompt_embedding)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=12,
        )
    predicted_image = sd15.lda.latents_to_image(x)
    ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_ip_adapter(
|
def test_diffusion_ip_adapter(
|
||||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_image_ella_adapter.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_image_ella_adapter.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 414 KiB |
|
@ -1,7 +1,11 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import piq # type: ignore
|
import piq # type: ignore
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
|
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
|
||||||
|
@ -16,3 +20,46 @@ def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int
|
||||||
assert (psnr >= min_psnr) and (
|
assert (psnr >= min_psnr) and (
|
||||||
ssim >= min_ssim
|
ssim >= min_ssim
|
||||||
), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
|
), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
|
||||||
|
|
||||||
|
|
||||||
|
class T5TextEmbedder(nn.Module):
    """Test helper embedding captions with a locally stored T5 encoder and tokenizer.

    Both the model and the tokenizer are loaded from `pretrained_path` with
    `local_files_only=True`.
    """

    def __init__(
        self, pretrained_path: Path = Path("tests/weights/QQGYLab/T5XLFP16"), max_length: int | None = None
    ) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.model: nn.Module = T5EncoderModel.from_pretrained(pretrained_path, local_files_only=True)  # type: ignore
        # Annotated with the imported class: the module name `transformers` is not in scope in this file.
        self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(pretrained_path, local_files_only=True)  # type: ignore
        # Default tokenization length used when `forward` is not given one.
        self.max_length = max_length

    def forward(
        self,
        caption: str,
        text_input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        max_length: int | None = None,
    ) -> torch.Tensor:
        """Return the encoder's last hidden state for `caption`.

        Args:
            caption: text to tokenize; only used when `text_input_ids` or
                `attention_mask` is missing.
            text_input_ids: pre-computed token ids.
            attention_mask: pre-computed attention mask.
            max_length: when set (here or on the instance), the caption is
                padded and truncated to this length; otherwise it is tokenized
                without padding or truncation.
        """
        if max_length is None:
            max_length = self.max_length

        if text_input_ids is None or attention_mask is None:
            tokenizer_options: dict[str, object] = {"return_tensors": "pt", "add_special_tokens": True}
            if max_length is not None:
                tokenizer_options.update(max_length=max_length, padding="max_length", truncation=True)
            text_inputs = self.tokenizer(caption, **tokenizer_options)  # type: ignore
            _text_input_ids: torch.Tensor = text_inputs.input_ids.to(self.model.device)  # type: ignore
            _attention_mask: torch.Tensor = text_inputs.attention_mask.to(self.model.device)  # type: ignore
        else:
            _text_input_ids: torch.Tensor = text_input_ids.to(self.model.device)  # type: ignore
            _attention_mask: torch.Tensor = attention_mask.to(self.model.device)  # type: ignore

        outputs = self.model(_text_input_ids, attention_mask=_attention_mask)

        embeddings = outputs.last_hidden_state
        return embeddings
|
||||||
|
|
Loading…
Reference in a new issue