implement CrossAttentionAdapter using chain operations

Pierre Chapuis 2023-09-08 18:51:54 +02:00
parent 43075f60b0
commit dc2c3e0163
2 changed files with 108 additions and 174 deletions
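Background: the adapter implements the "decoupled cross-attention" of IP-Adapter, as the chain below spells out: the UNet query attends once to the text keys/values and once to the image keys/values, and the two results are summed, with the image branch scaled. A minimal torch-only sketch of that computation (illustrative only, not refiners code; for brevity one tensor stands in for each branch's keys and values):

import torch
import torch.nn.functional as F

batch, query_len, text_len, image_len, dim = 1, 32, 77, 4, 64
scale = 1.0

query = torch.randn(batch, query_len, dim)     # queries from the UNet features (Wq branch)
text_kv = torch.randn(batch, text_len, dim)    # projected text embedding (Wk/Wv branch)
image_kv = torch.randn(batch, image_len, dim)  # projected image embedding (Wk'/Wv' branch)

text_attn = F.scaled_dot_product_attention(query, text_kv, text_kv)
image_attn = F.scaled_dot_product_attention(query, image_kv, image_kv)

output = text_attn + scale * image_attn  # the fl.Sum of the two branches
print(output.shape)  # torch.Size([1, 32, 64])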

View file

@@ -1,14 +1,16 @@
 from enum import IntEnum
 from functools import partial
-from typing import Generic, TypeVar, Any
+from typing import Generic, TypeVar, Any, Callable
 from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
 from PIL import Image
 from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.adapters.lora import Lora
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from refiners.fluxion.layers.module import Module
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor
 import refiners.fluxion.layers as fl
@@ -48,133 +50,8 @@ class _CrossAttnIndex(IntEnum):
     IMG_CROSS_ATTN = 1  # image cross-attention
-# Fluxion's Attention layer drop-in replacement implementing Decoupled Cross-Attention
-class IPAttention(fl.Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "text_sequence_length",
-        "image_sequence_length",
-        "scale",
-        "num_heads",
-        "heads_dim",
-        "key_embedding_dim",
-        "value_embedding_dim",
-        "inner_dim",
-        "use_bias",
-        "is_causal",
-    ]
-
-    def __init__(
-        self,
-        embedding_dim: int,
-        text_sequence_length: int = 77,
-        image_sequence_length: int = 4,
-        scale: float = 1.0,
-        num_heads: int = 1,
-        key_embedding_dim: int | None = None,
-        value_embedding_dim: int | None = None,
-        inner_dim: int | None = None,
-        use_bias: bool = True,
-        is_causal: bool | None = None,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
-        assert (
-            embedding_dim % num_heads == 0
-        ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
-        self.embedding_dim = embedding_dim
-        self.text_sequence_length = text_sequence_length
-        self.image_sequence_length = image_sequence_length
-        self.scale = scale
-        self.num_heads = num_heads
-        self.heads_dim = embedding_dim // num_heads
-        self.key_embedding_dim = key_embedding_dim or embedding_dim
-        self.value_embedding_dim = value_embedding_dim or embedding_dim
-        self.inner_dim = inner_dim or embedding_dim
-        self.use_bias = use_bias
-        self.is_causal = is_causal
-        super().__init__(
-            fl.Distribute(
-                # Note: the same query is used for image cross-attention as for text cross-attention
-                fl.Linear(
-                    in_features=self.embedding_dim,
-                    out_features=self.inner_dim,
-                    bias=self.use_bias,
-                    device=device,
-                    dtype=dtype,
-                ),  # Wq
-                fl.Parallel(
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wk
-                    ),
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wk'
-                    ),
-                ),
-                fl.Parallel(
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wv
-                    ),
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wv'
-                    ),
-                ),
-            ),
-            fl.Sum(
-                fl.Chain(
-                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
-                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
-                ),
-                fl.Chain(
-                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
-                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
-                    fl.Lambda(func=self.scale_outputs),
-                ),
-            ),
-            fl.Linear(
-                in_features=self.inner_dim,
-                out_features=self.embedding_dim,
-                bias=True,
-                device=device,
-                dtype=dtype,
-            ),
-        )
-
-    def select_qkv(
-        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
-    ) -> tuple[Tensor, Tensor, Tensor]:
-        return (query, keys[index.value], values[index.value])
-
-    def scale_outputs(self, x: Tensor) -> Tensor:
-        return x * self.scale
+class InjectionPoint(fl.Chain):
+    pass
 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
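The new InjectionPoint is simply an empty Chain used as a named slot: as the next hunk shows, inject() appends the target's existing layer into it and eject() pops it back out. A minimal sketch of that mechanism, assuming refiners is installed and using only the Chain operations that appear in this diff (append, pop, len):

import torch
import refiners.fluxion.layers as fl

linear = fl.Linear(in_features=4, out_features=4, bias=True)

slot = fl.Chain()        # empty placeholder, analogous to InjectionPoint
slot.append(linear)      # what inject() does: reuse the target's own layer
assert len(slot) == 1

x = torch.randn(2, 4)
assert torch.equal(slot(x), linear(x))  # the slot just forwards to the shared layer

slot.pop()               # what eject() does
assert len(slot) == 0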
@@ -190,46 +67,100 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
         self.text_sequence_length = text_sequence_length
         self.image_sequence_length = image_sequence_length
         self.scale = scale
         with self.setup_adapter(target):
             super().__init__(
-                IPAttention(
-                    embedding_dim=target.embedding_dim,
-                    text_sequence_length=text_sequence_length,
-                    image_sequence_length=image_sequence_length,
-                    scale=scale,
-                    num_heads=target.num_heads,
-                    key_embedding_dim=target.key_embedding_dim,
-                    value_embedding_dim=target.value_embedding_dim,
-                    inner_dim=target.inner_dim,
-                    use_bias=target.use_bias,
-                    is_causal=target.is_causal,
-                    device=target.device,
-                    dtype=target.dtype,
-                )
+                fl.Distribute(
+                    # Note: the same query is used for image cross-attention as for text cross-attention
+                    InjectionPoint(),  # Wq
+                    fl.Parallel(
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                            InjectionPoint(),  # Wk
+                        ),
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                            fl.Linear(
+                                in_features=self.target.key_embedding_dim,
+                                out_features=self.target.inner_dim,
+                                bias=self.target.use_bias,
+                                device=target.device,
+                                dtype=target.dtype,
+                            ),  # Wk'
+                        ),
+                    ),
+                    fl.Parallel(
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                            InjectionPoint(),  # Wv
+                        ),
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                            fl.Linear(
+                                in_features=self.target.key_embedding_dim,
+                                out_features=self.target.inner_dim,
+                                bias=self.target.use_bias,
+                                device=target.device,
+                                dtype=target.dtype,
+                            ),  # Wv'
+                        ),
+                    ),
+                ),
+                fl.Sum(
+                    fl.Chain(
+                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
+                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
+                    ),
+                    fl.Chain(
+                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
+                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
+                        fl.Lambda(func=self.scale_outputs),
+                    ),
+                ),
+                InjectionPoint(),  # proj
             )
-    def get_parameter_name(self, matrix: str, bias: bool = False) -> str:
-        match matrix:
-            case "wq":
-                index = 0
-            case "wk":
-                index = 1
-            case "wk_prime":
-                index = 2
-            case "wv":
-                index = 3
-            case "wv_prime":
-                index = 4
-            case "proj":
-                index = 5
-            case _:
-                raise ValueError(f"Unexpected matrix name {matrix}")
-
-        linear = list(self.IPAttention.layers(fl.Linear))[index]
-        param = getattr(linear, "bias" if bias else "weight")
-        name = next((n for n, p in self.named_parameters() if id(p) == id(param)), None)
-        assert name is not None
-        return name
+    def select_qkv(
+        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
+    ) -> tuple[Tensor, Tensor, Tensor]:
+        return (query, keys[index.value], values[index.value])
+
+    def scale_outputs(self, x: Tensor) -> Tensor:
+        return x * self.scale
+
+    def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]:
+        def f(m: fl.Module, _: fl.Chain) -> bool:
+            if isinstance(m, Lora):  # do not adapt LoRAs
+                raise StopIteration
+            return isinstance(m, k)
+
+        return f
+
+    def _target_linears(self) -> list[fl.Linear]:
+        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
+
+    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
+        linears = self._target_linears()
+        assert len(linears) == 4  # Wq, Wk, Wv and Proj
+
+        injection_points = list(self.layers(InjectionPoint))
+        assert len(injection_points) == 4
+
+        for linear, ip in zip(linears, injection_points):
+            ip.append(linear)
+            assert len(ip) == 1
+
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        injection_points = list(self.layers(InjectionPoint))
+        assert len(injection_points) == 4
+
+        for ip in injection_points:
+            ip.pop()
+            assert len(ip) == 0
+
+        super().eject()
 class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
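In _predicate above, raising StopIteration from inside the predicate passed to walk() is what keeps Linears living under LoRA wrappers from being picked as injection targets ("do not adapt LoRAs"), presumably by pruning that whole subtree from the traversal. A generic sketch of that pruning pattern in plain Python, not refiners' actual walk() implementation:

from typing import Callable, Iterator

class Node:
    def __init__(self, name: str, children: list["Node"] | None = None) -> None:
        self.name = name
        self.children = children or []

def walk(node: Node, predicate: Callable[[Node], bool]) -> Iterator[Node]:
    try:
        keep = predicate(node)
    except StopIteration:
        return  # predicate raised: skip this node and its whole subtree
    if keep:
        yield node
    for child in node.children:
        yield from walk(child, predicate)

def linear_but_not_in_lora(node: Node) -> bool:
    if node.name == "Lora":
        raise StopIteration  # do not adapt LoRAs
    return node.name == "Linear"

attention = Node("Attention", [Node("Linear"), Node("Lora", [Node("Linear")])])
assert [n.name for n in walk(attention, linear_but_not_in_lora)] == ["Linear"]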
@@ -265,17 +196,6 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
                    continue
                cross_attn_state_dict[k.removeprefix(prefix)] = v
-            # Retrieve original (frozen) cross-attention weights
-            # Note: this assumes the target UNet has already loaded weights
-            cross_attn_linears = list(cross_attn.target.layers(fl.Linear))
-            assert len(cross_attn_linears) == 4  # Wq, Wk, Wv and Proj
-            cross_attn_state_dict[cross_attn.get_parameter_name("wq")] = cross_attn_linears[0].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("wk")] = cross_attn_linears[1].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("wv")] = cross_attn_linears[2].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("proj")] = cross_attn_linears[3].weight
-            cross_attn_state_dict[cross_attn.get_parameter_name("proj", bias=True)] = cross_attn_linears[3].bias
             cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
     def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
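The block removed above copied the frozen UNet weights into the adapter by parameter name. With injection that copy is no longer needed: inject() appends the target's own Linear modules into the InjectionPoints, so the adapter and the target literally share parameters. A torch-only sketch of the idea (hypothetical module names, not refiners code):

import torch
from torch import nn

target_linear = nn.Linear(4, 4)           # stands in for a frozen Wq/Wk/Wv/proj of the UNet
adapter = nn.ModuleList([target_linear])  # the adapter holds the very same object

# Loading weights on the target alone...
target_linear.load_state_dict({"weight": torch.ones(4, 4), "bias": torch.zeros(4)})

# ...is immediately visible through the adapter, no name-based copying required.
assert torch.equal(adapter[0].weight, torch.ones(4, 4))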

View file

@@ -0,0 +1,14 @@
+import refiners.fluxion.layers as fl
+from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter
+
+
+def test_cross_attention_adapter() -> None:
+    base = fl.Chain(fl.Attention(embedding_dim=4))
+    adapter = CrossAttentionAdapter(base.Attention).inject()
+
+    assert list(base) == [adapter]
+
+    adapter.eject()
+
+    assert len(base) == 1
+    assert isinstance(base[0], fl.Attention)