mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00

implement CrossAttentionAdapter using chain operations

This commit is contained in:
parent 43075f60b0
commit dc2c3e0163
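For context, a minimal usage sketch of the adapter this commit introduces (not part of the diff; it mirrors the new test added at the bottom of this commit):

    import refiners.fluxion.layers as fl
    from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter

    # Wrap a cross-attention layer, splice the adapter into its parent chain, then undo it.
    base = fl.Chain(fl.Attention(embedding_dim=4))
    adapter = CrossAttentionAdapter(base.Attention)
    adapter.inject()   # the adapter takes the Attention layer's place inside `base`
    assert list(base) == [adapter]
    adapter.eject()    # the original Attention layer is restored
    assert isinstance(base[0], fl.Attention)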
@@ -1,14 +1,16 @@
 from enum import IntEnum
 from functools import partial
-from typing import Generic, TypeVar, Any
+from typing import Generic, TypeVar, Any, Callable

 from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
 from PIL import Image

 from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.adapters.lora import Lora
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from refiners.fluxion.layers.module import Module
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor
 import refiners.fluxion.layers as fl
@@ -48,133 +50,8 @@ class _CrossAttnIndex(IntEnum):
     IMG_CROSS_ATTN = 1  # image cross-attention


-# Fluxion's Attention layer drop-in replacement implementing Decoupled Cross-Attention
-class IPAttention(fl.Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "text_sequence_length",
-        "image_sequence_length",
-        "scale",
-        "num_heads",
-        "heads_dim",
-        "key_embedding_dim",
-        "value_embedding_dim",
-        "inner_dim",
-        "use_bias",
-        "is_causal",
-    ]
-
-    def __init__(
-        self,
-        embedding_dim: int,
-        text_sequence_length: int = 77,
-        image_sequence_length: int = 4,
-        scale: float = 1.0,
-        num_heads: int = 1,
-        key_embedding_dim: int | None = None,
-        value_embedding_dim: int | None = None,
-        inner_dim: int | None = None,
-        use_bias: bool = True,
-        is_causal: bool | None = None,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
-        assert (
-            embedding_dim % num_heads == 0
-        ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
-        self.embedding_dim = embedding_dim
-        self.text_sequence_length = text_sequence_length
-        self.image_sequence_length = image_sequence_length
-        self.scale = scale
-        self.num_heads = num_heads
-        self.heads_dim = embedding_dim // num_heads
-        self.key_embedding_dim = key_embedding_dim or embedding_dim
-        self.value_embedding_dim = value_embedding_dim or embedding_dim
-        self.inner_dim = inner_dim or embedding_dim
-        self.use_bias = use_bias
-        self.is_causal = is_causal
-        super().__init__(
-            fl.Distribute(
-                # Note: the same query is used for image cross-attention as for text cross-attention
-                fl.Linear(
-                    in_features=self.embedding_dim,
-                    out_features=self.inner_dim,
-                    bias=self.use_bias,
-                    device=device,
-                    dtype=dtype,
-                ),  # Wq
-                fl.Parallel(
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wk
-                    ),
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wk'
-                    ),
-                ),
-                fl.Parallel(
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wv
-                    ),
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wv'
-                    ),
-                ),
-            ),
-            fl.Sum(
-                fl.Chain(
-                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
-                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
-                ),
-                fl.Chain(
-                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
-                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
-                    fl.Lambda(func=self.scale_outputs),
-                ),
-            ),
-            fl.Linear(
-                in_features=self.inner_dim,
-                out_features=self.embedding_dim,
-                bias=True,
-                device=device,
-                dtype=dtype,
-            ),
-        )
-
-    def select_qkv(
-        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
-    ) -> tuple[Tensor, Tensor, Tensor]:
-        return (query, keys[index.value], values[index.value])
-
-    def scale_outputs(self, x: Tensor) -> Tensor:
-        return x * self.scale
+class InjectionPoint(fl.Chain):
+    pass


 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
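InjectionPoint is deliberately an empty fl.Chain: it marks the slots (Wq, Wk, Wv and the output projection) where the adapter grafts the target Attention's own fl.Linear layers instead of creating copies. A rough sketch of the pattern, not part of the diff, using only the chain operations that inject/eject below rely on (append, pop, len):

    slot = InjectionPoint()                # empty placeholder chain
    shared_linear = fl.Linear(in_features=4, out_features=4)

    slot.append(shared_linear)             # inject(): reuse the target's weights, no copy
    assert len(slot) == 1
    slot.pop()                             # eject(): the placeholder is empty again
    assert len(slot) == 0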
@@ -190,46 +67,100 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
         self.text_sequence_length = text_sequence_length
         self.image_sequence_length = image_sequence_length
         self.scale = scale

         with self.setup_adapter(target):
             super().__init__(
-                IPAttention(
-                    embedding_dim=target.embedding_dim,
-                    text_sequence_length=text_sequence_length,
-                    image_sequence_length=image_sequence_length,
-                    scale=scale,
-                    num_heads=target.num_heads,
-                    key_embedding_dim=target.key_embedding_dim,
-                    value_embedding_dim=target.value_embedding_dim,
-                    inner_dim=target.inner_dim,
-                    use_bias=target.use_bias,
-                    is_causal=target.is_causal,
-                    device=target.device,
-                    dtype=target.dtype,
-                )
+                fl.Distribute(
+                    # Note: the same query is used for image cross-attention as for text cross-attention
+                    InjectionPoint(),  # Wq
+                    fl.Parallel(
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                            InjectionPoint(),  # Wk
+                        ),
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                            fl.Linear(
+                                in_features=self.target.key_embedding_dim,
+                                out_features=self.target.inner_dim,
+                                bias=self.target.use_bias,
+                                device=target.device,
+                                dtype=target.dtype,
+                            ),  # Wk'
+                        ),
+                    ),
+                    fl.Parallel(
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                            InjectionPoint(),  # Wv
+                        ),
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                            fl.Linear(
+                                in_features=self.target.key_embedding_dim,
+                                out_features=self.target.inner_dim,
+                                bias=self.target.use_bias,
+                                device=target.device,
+                                dtype=target.dtype,
+                            ),  # Wv'
+                        ),
+                    ),
+                ),
+                fl.Sum(
+                    fl.Chain(
+                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
+                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
+                    ),
+                    fl.Chain(
+                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
+                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
+                        fl.Lambda(func=self.scale_outputs),
+                    ),
+                ),
+                InjectionPoint(),  # proj
             )

-    def get_parameter_name(self, matrix: str, bias: bool = False) -> str:
-        match matrix:
-            case "wq":
-                index = 0
-            case "wk":
-                index = 1
-            case "wk_prime":
-                index = 2
-            case "wv":
-                index = 3
-            case "wv_prime":
-                index = 4
-            case "proj":
-                index = 5
-            case _:
-                raise ValueError(f"Unexpected matrix name {matrix}")
-
-        linear = list(self.IPAttention.layers(fl.Linear))[index]
-        param = getattr(linear, "bias" if bias else "weight")
-        name = next((n for n, p in self.named_parameters() if id(p) == id(param)), None)
-        assert name is not None
-        return name
+    def select_qkv(
+        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
+    ) -> tuple[Tensor, Tensor, Tensor]:
+        return (query, keys[index.value], values[index.value])
+
+    def scale_outputs(self, x: Tensor) -> Tensor:
+        return x * self.scale
+
+    def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]:
+        def f(m: fl.Module, _: fl.Chain) -> bool:
+            if isinstance(m, Lora):  # do not adapt LoRAs
+                raise StopIteration
+            return isinstance(m, k)
+
+        return f
+
+    def _target_linears(self) -> list[fl.Linear]:
+        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
+
+    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
+        linears = self._target_linears()
+        assert len(linears) == 4  # Wq, Wk, Wv and Proj
+
+        injection_points = list(self.layers(InjectionPoint))
+        assert len(injection_points) == 4
+
+        for linear, ip in zip(linears, injection_points):
+            ip.append(linear)
+            assert len(ip) == 1
+
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        injection_points = list(self.layers(InjectionPoint))
+        assert len(injection_points) == 4
+
+        for ip in injection_points:
+            ip.pop()
+            assert len(ip) == 0
+
+        super().eject()


 class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
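The chain above expresses decoupled cross-attention: the shared query attends to the text tokens and the image tokens separately, and the image branch is scaled before the two results are summed. The same computation written out with plain torch, as an illustrative sketch only (multi-head reshaping and the output projection are omitted):

    import torch
    import torch.nn.functional as F

    def decoupled_cross_attention(
        q: torch.Tensor,      # shared query, projected by Wq
        k_txt: torch.Tensor,  # keys from the text tokens (Wk)
        v_txt: torch.Tensor,  # values from the text tokens (Wv)
        k_img: torch.Tensor,  # keys from the image tokens (Wk')
        v_img: torch.Tensor,  # values from the image tokens (Wv')
        scale: float = 1.0,
    ) -> torch.Tensor:
        txt_attn = F.scaled_dot_product_attention(q, k_txt, v_txt)  # TXT_CROSS_ATTN branch
        img_attn = F.scaled_dot_product_attention(q, k_img, v_img)  # IMG_CROSS_ATTN branch
        return txt_attn + scale * img_attn                          # fl.Sum + scale_outputs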
@@ -265,17 +196,6 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
                    continue
                cross_attn_state_dict[k.removeprefix(prefix)] = v

-            # Retrieve original (frozen) cross-attention weights
-            # Note: this assumes the target UNet has already loaded weights
-            cross_attn_linears = list(cross_attn.target.layers(fl.Linear))
-            assert len(cross_attn_linears) == 4  # Wq, Wk, Wv and Proj
-
-            cross_attn_state_dict[cross_attn.get_parameter_name("wq")] = cross_attn_linears[0].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("wk")] = cross_attn_linears[1].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("wv")] = cross_attn_linears[2].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("proj")] = cross_attn_linears[3].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("proj", bias=True)] = cross_attn_linears[3].bias
-
             cross_attn.load_state_dict(state_dict=cross_attn_state_dict)

     def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
tests/foundationals/latent_diffusion/test_image_prompt.py (new file, +14 lines)
@@ -0,0 +1,14 @@
+import refiners.fluxion.layers as fl
+from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter
+
+
+def test_cross_attention_adapter() -> None:
+    base = fl.Chain(fl.Attention(embedding_dim=4))
+    adapter = CrossAttentionAdapter(base.Attention).inject()
+
+    assert list(base) == [adapter]
+
+    adapter.eject()
+
+    assert len(base) == 1
+    assert isinstance(base[0], fl.Attention)