diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 32cade0..0acf298 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,14 +1,16 @@
 from enum import IntEnum
 from functools import partial
-from typing import Generic, TypeVar, Any
+from typing import Generic, TypeVar, Any, Callable
 
 from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
 from PIL import Image
 
 from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.adapters.lora import Lora
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from refiners.fluxion.layers.module import Module
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor
 import refiners.fluxion.layers as fl
@@ -48,133 +50,8 @@ class _CrossAttnIndex(IntEnum):
     IMG_CROSS_ATTN = 1  # image cross-attention
 
 
-# Fluxion's Attention layer drop-in replacement implementing Decoupled Cross-Attention
-class IPAttention(fl.Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "text_sequence_length",
-        "image_sequence_length",
-        "scale",
-        "num_heads",
-        "heads_dim",
-        "key_embedding_dim",
-        "value_embedding_dim",
-        "inner_dim",
-        "use_bias",
-        "is_causal",
-    ]
-
-    def __init__(
-        self,
-        embedding_dim: int,
-        text_sequence_length: int = 77,
-        image_sequence_length: int = 4,
-        scale: float = 1.0,
-        num_heads: int = 1,
-        key_embedding_dim: int | None = None,
-        value_embedding_dim: int | None = None,
-        inner_dim: int | None = None,
-        use_bias: bool = True,
-        is_causal: bool | None = None,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
-        assert (
-            embedding_dim % num_heads == 0
-        ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
-        self.embedding_dim = embedding_dim
-        self.text_sequence_length = text_sequence_length
-        self.image_sequence_length = image_sequence_length
-        self.scale = scale
-        self.num_heads = num_heads
-        self.heads_dim = embedding_dim // num_heads
-        self.key_embedding_dim = key_embedding_dim or embedding_dim
-        self.value_embedding_dim = value_embedding_dim or embedding_dim
-        self.inner_dim = inner_dim or embedding_dim
-        self.use_bias = use_bias
-        self.is_causal = is_causal
-        super().__init__(
-            fl.Distribute(
-                # Note: the same query is used for image cross-attention as for text cross-attention
-                fl.Linear(
-                    in_features=self.embedding_dim,
-                    out_features=self.inner_dim,
-                    bias=self.use_bias,
-                    device=device,
-                    dtype=dtype,
-                ),  # Wq
-                fl.Parallel(
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wk
-                    ),
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wk'
-                    ),
-                ),
-                fl.Parallel(
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wv
-                    ),
-                    fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
-                        fl.Linear(
-                            in_features=self.key_embedding_dim,
-                            out_features=self.inner_dim,
-                            bias=self.use_bias,
-                            device=device,
-                            dtype=dtype,
-                        ),  # Wv'
-                    ),
-                ),
-            ),
-            fl.Sum(
-                fl.Chain(
-                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
-                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
-                ),
-                fl.Chain(
-                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
-                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
-                    fl.Lambda(func=self.scale_outputs),
-                ),
-            ),
-            fl.Linear(
-                in_features=self.inner_dim,
-                out_features=self.embedding_dim,
-                bias=True,
-                device=device,
-                dtype=dtype,
-            ),
-        )
-
-    def select_qkv(
-        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
-    ) -> tuple[Tensor, Tensor, Tensor]:
-        return (query, keys[index.value], values[index.value])
-
-    def scale_outputs(self, x: Tensor) -> Tensor:
-        return x * self.scale
+class InjectionPoint(fl.Chain):
+    pass
 
 
 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
@@ -190,46 +67,100 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
         self.text_sequence_length = text_sequence_length
         self.image_sequence_length = image_sequence_length
         self.scale = scale
+
         with self.setup_adapter(target):
             super().__init__(
-                IPAttention(
-                    embedding_dim=target.embedding_dim,
-                    text_sequence_length=text_sequence_length,
-                    image_sequence_length=image_sequence_length,
-                    scale=scale,
-                    num_heads=target.num_heads,
-                    key_embedding_dim=target.key_embedding_dim,
-                    value_embedding_dim=target.value_embedding_dim,
-                    inner_dim=target.inner_dim,
-                    use_bias=target.use_bias,
-                    is_causal=target.is_causal,
-                    device=target.device,
-                    dtype=target.dtype,
-                )
+                fl.Distribute(
+                    # Note: the same query is used for image cross-attention as for text cross-attention
+                    InjectionPoint(),  # Wq
+                    fl.Parallel(
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                            InjectionPoint(),  # Wk
+                        ),
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                            fl.Linear(
+                                in_features=self.target.key_embedding_dim,
+                                out_features=self.target.inner_dim,
+                                bias=self.target.use_bias,
+                                device=target.device,
+                                dtype=target.dtype,
+                            ),  # Wk'
+                        ),
+                    ),
+                    fl.Parallel(
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                            InjectionPoint(),  # Wv
+                        ),
+                        fl.Chain(
+                            fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                            fl.Linear(
+                                in_features=self.target.key_embedding_dim,
+                                out_features=self.target.inner_dim,
+                                bias=self.target.use_bias,
+                                device=target.device,
+                                dtype=target.dtype,
+                            ),  # Wv'
+                        ),
+                    ),
+                ),
+                fl.Sum(
+                    fl.Chain(
+                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
+                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
+                    ),
+                    fl.Chain(
+                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
+                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
+                        fl.Lambda(func=self.scale_outputs),
+                    ),
+                ),
+                InjectionPoint(),  # proj
             )
 
-    def get_parameter_name(self, matrix: str, bias: bool = False) -> str:
-        match matrix:
-            case "wq":
-                index = 0
-            case "wk":
-                index = 1
-            case "wk_prime":
-                index = 2
-            case "wv":
-                index = 3
-            case "wv_prime":
-                index = 4
-            case "proj":
-                index = 5
-            case _:
-                raise ValueError(f"Unexpected matrix name {matrix}")
+    def select_qkv(
+        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
+    ) -> tuple[Tensor, Tensor, Tensor]:
+        return (query, keys[index.value], values[index.value])
 
-        linear = list(self.IPAttention.layers(fl.Linear))[index]
-        param = getattr(linear, "bias" if bias else "weight")
-        name = next((n for n, p in self.named_parameters() if id(p) == id(param)), None)
-        assert name is not None
-        return name
+    def scale_outputs(self, x: Tensor) -> Tensor:
+        return x * self.scale
+
+    def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]:
+        def f(m: fl.Module, _: fl.Chain) -> bool:
+            if isinstance(m, Lora):  # do not adapt LoRAs
+                raise StopIteration
+            return isinstance(m, k)
+
+        return f
+
+    def _target_linears(self) -> list[fl.Linear]:
+        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
+
+    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
+        linears = self._target_linears()
+        assert len(linears) == 4  # Wq, Wk, Wv and Proj
+
+        injection_points = list(self.layers(InjectionPoint))
+        assert len(injection_points) == 4
+
+        for linear, ip in zip(linears, injection_points):
+            ip.append(linear)
+            assert len(ip) == 1
+
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        injection_points = list(self.layers(InjectionPoint))
+        assert len(injection_points) == 4
+
+        for ip in injection_points:
+            ip.pop()
+            assert len(ip) == 0
+
+        super().eject()
 
 
 class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@@ -265,17 +196,6 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
                 continue
             cross_attn_state_dict[k.removeprefix(prefix)] = v
 
-        # Retrieve original (frozen) cross-attention weights
-        # Note: this assumes the target UNet has already loaded weights
-        cross_attn_linears = list(cross_attn.target.layers(fl.Linear))
-        assert len(cross_attn_linears) == 4  # Wq, Wk, Wv and Proj
-
-        cross_attn_state_dict[cross_attn.get_parameter_name("wq")] = cross_attn_linears[0].weight
-        cross_attn_state_dict[cross_attn.get_parameter_name("wk")] = cross_attn_linears[1].weight
-        cross_attn_state_dict[cross_attn.get_parameter_name("wv")] = cross_attn_linears[2].weight
-        cross_attn_state_dict[cross_attn.get_parameter_name("proj")] = cross_attn_linears[3].weight
-        cross_attn_state_dict[cross_attn.get_parameter_name("proj", bias=True)] = cross_attn_linears[3].bias
-
         cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
 
     def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
diff --git a/tests/foundationals/latent_diffusion/test_image_prompt.py b/tests/foundationals/latent_diffusion/test_image_prompt.py
new file mode 100644
index 0000000..505bc78
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_image_prompt.py
@@ -0,0 +1,14 @@
+import refiners.fluxion.layers as fl
+from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter
+
+
+def test_cross_attention_adapter() -> None:
+    base = fl.Chain(fl.Attention(embedding_dim=4))
+    adapter = CrossAttentionAdapter(base.Attention).inject()
+
+    assert list(base) == [adapter]
+
+    adapter.eject()
+
+    assert len(base) == 1
+    assert isinstance(base[0], fl.Attention)