diff --git a/README.md b/README.md
index 4c52c9f..dad2e4c 100644
--- a/README.md
+++ b/README.md
@@ -128,16 +128,7 @@ with no_grad():
         text="best quality, high quality",
         negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
-
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
 
     time_ids = sdxl.default_time_ids
 
     condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py
index 8fd33be..8282db1 100644
--- a/scripts/conversion/convert_diffusers_ip_adapter.py
+++ b/scripts/conversion/convert_diffusers_ip_adapter.py
@@ -133,24 +133,14 @@ def main() -> None:
     ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
     assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
 
-    for i, cross_attn in enumerate(ip_adapter.sub_adapters):
+    for i, _ in enumerate(ip_adapter.sub_adapters):
         cross_attn_index = cross_attn_mapping[i]
         k_ip = f"{cross_attn_index}.to_k_ip.weight"
         v_ip = f"{cross_attn_index}.to_v_ip.weight"
 
-        # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
-
-        names = [k for k, _ in cross_attn.named_parameters()]
-        assert len(names) == 2
-
-        cross_attn_state_dict: dict[str, Any] = {
-            names[0]: ip_adapter_weights[k_ip],
-            names[1]: ip_adapter_weights[v_ip],
-        }
-        cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
-
-        for k, v in cross_attn_state_dict.items():
-            state_dict[f"ip_adapter.{i:03d}.{k}"] = v
+        # The key names are not checked at runtime, so we keep the original names.
+        state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
+        state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]
 
     if args.half:
         state_dict = {key: value.half() for key, value in state_dict.items()}
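Since the converter now writes the `to_k_ip`/`to_v_ip` key names verbatim, a converted checkpoint can be sanity-checked directly. A minimal sketch, assuming refiners' `load_from_safetensors` helper and a hypothetical output path:

```python
from refiners.fluxion.utils import load_from_safetensors

weights = load_from_safetensors("ip_adapter.safetensors")  # hypothetical output path
adapter_keys = [k for k in weights if k.startswith("ip_adapter.")]
# two tensors per cross-attention sub-adapter, under their original names
assert len(adapter_keys) % 2 == 0
assert all(k.endswith(("to_k_ip.weight", "to_v_ip.weight")) for k in adapter_keys)
```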
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 9c4bac0..b0cdfdd 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,17 +1,15 @@
 import math
-from enum import IntEnum
-from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
 
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
+from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
-from refiners.fluxion.adapters.lora import Lora
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
+from refiners.fluxion.layers.chain import Distribute
 from refiners.fluxion.utils import image_to_tensor, normalize
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@@ -236,120 +234,89 @@ class PerceiverResampler(fl.Chain):
         return {"perceiver_resampler": {"x": None}}
 
 
-class _CrossAttnIndex(IntEnum):
-    TXT_CROSS_ATTN = 0  # text cross-attention
-    IMG_CROSS_ATTN = 1  # image cross-attention
-
-
 class InjectionPoint(fl.Chain):
     pass
 
 
+class ImageCrossAttention(fl.Chain):
+    def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
+        super().__init__(
+            fl.Distribute(
+                fl.UseContext(context="ip_adapter", key="query_projection"),
+                fl.Chain(
+                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
+                    fl.Linear(
+                        in_features=text_cross_attention.key_embedding_dim,
+                        out_features=text_cross_attention.inner_dim,
+                        bias=text_cross_attention.use_bias,
+                        device=text_cross_attention.device,
+                        dtype=text_cross_attention.dtype,
+                    ),
+                ),
+                fl.Chain(
+                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
+                    fl.Linear(
+                        in_features=text_cross_attention.key_embedding_dim,
+                        out_features=text_cross_attention.inner_dim,
+                        bias=text_cross_attention.use_bias,
+                        device=text_cross_attention.device,
+                        dtype=text_cross_attention.dtype,
+                    ),
+                ),
+            ),
+            ScaledDotProductAttention(
+                num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
+            ),
+            fl.Multiply(scale),
+        )
+
+    @property
+    def scale(self) -> float:
+        return self.ensure_find(fl.Multiply).scale
+
+    @scale.setter
+    def scale(self, value: float) -> None:
+        # Update the Multiply layer in place so that scale changes take effect after injection.
+        self.ensure_find(fl.Multiply).scale = value
+
+
+class SetQueryProjection(fl.Passthrough):
+    def __init__(self) -> None:
+        super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection"))
+
+
 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
     def __init__(
         self,
         target: fl.Attention,
-        text_sequence_length: int = 77,
-        image_sequence_length: int = 4,
         scale: float = 1.0,
     ) -> None:
-        self.text_sequence_length = text_sequence_length
-        self.image_sequence_length = image_sequence_length
-        self.scale = scale
-
         with self.setup_adapter(target):
             super().__init__(
-                fl.Distribute(
-                    # Note: the same query is used for image cross-attention as for text cross-attention
-                    InjectionPoint(),  # Wq
-                    fl.Parallel(
-                        fl.Chain(
-                            fl.Slicing(dim=1, end=text_sequence_length),
-                            InjectionPoint(),  # Wk
-                        ),
-                        fl.Chain(
-                            fl.Slicing(dim=1, start=text_sequence_length),
-                            fl.Linear(
-                                in_features=self.target.key_embedding_dim,
-                                out_features=self.target.inner_dim,
-                                bias=self.target.use_bias,
-                                device=target.device,
-                                dtype=target.dtype,
-                            ),  # Wk'
-                        ),
-                    ),
-                    fl.Parallel(
-                        fl.Chain(
-                            fl.Slicing(dim=1, end=text_sequence_length),
-                            InjectionPoint(),  # Wv
-                        ),
-                        fl.Chain(
-                            fl.Slicing(dim=1, start=text_sequence_length),
-                            fl.Linear(
-                                in_features=self.target.key_embedding_dim,
-                                out_features=self.target.inner_dim,
-                                bias=self.target.use_bias,
-                                device=target.device,
-                                dtype=target.dtype,
-                            ),  # Wv'
-                        ),
-                    ),
-                ),
                 fl.Sum(
-                    fl.Chain(
-                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
-                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
-                    ),
-                    fl.Chain(
-                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
-                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
-                        fl.Lambda(func=self.scale_outputs),
-                    ),
+                    target[:-1],  # original text cross-attention
+                    ImageCrossAttention(text_cross_attention=target, scale=scale),
                 ),
-                InjectionPoint(),  # proj
+                target[-1],  # projection
             )
+        self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())
 
-    def select_qkv(
-        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
-    ) -> tuple[Tensor, Tensor, Tensor]:
-        return (query, keys[index.value], values[index.value])
+    @property
+    def image_cross_attention(self) -> ImageCrossAttention:
+        return self.ensure_find(ImageCrossAttention)
 
-    def scale_outputs(self, x: Tensor) -> Tensor:
-        return x * self.scale
+    @property
+    def image_key_projection(self) -> fl.Linear:
+        return self.image_cross_attention.Distribute[1].Linear
 
-    def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
-        def f(m: fl.Module, _: fl.Chain) -> bool:
-            if isinstance(m, Lora):  # do not adapt LoRAs
-                raise StopIteration
-            return isinstance(m, k)
-
-        return f
+    @property
+    def image_value_projection(self) -> fl.Linear:
+        return self.image_cross_attention.Distribute[2].Linear
 
-    def _target_linears(self) -> list[fl.Linear]:
-        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
+    @property
+    def scale(self) -> float:
+        return self.image_cross_attention.scale
 
-    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
-        linears = self._target_linears()
-        assert len(linears) == 4  # Wq, Wk, Wv and Proj
-
-        injection_points = list(self.layers(InjectionPoint))
-        assert len(injection_points) == 4
-
-        for linear, ip in zip(linears, injection_points):
-            ip.append(linear)
-            assert len(ip) == 1
-
-        return super().inject(parent)
+    @scale.setter
+    def scale(self, value: float) -> None:
+        self.image_cross_attention.scale = value
 
-    def eject(self) -> None:
-        injection_points = list(self.layers(InjectionPoint))
-        assert len(injection_points) == 4
-
-        for ip in injection_points:
-            ip.pop()
-            assert len(ip) == 0
-
-        super().eject()
+    def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
+        self.image_key_projection.weight = nn.Parameter(key_tensor)
+        self.image_value_projection.weight = nn.Parameter(value_tensor)
+        self.image_cross_attention.to(self.device, self.dtype)
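The new wiring relies on fluxion's context mechanism: `SetQueryProjection` is a `Passthrough` inserted right after the target's `Distribute`, so it stashes the text cross-attention's query projection, and `ImageCrossAttention` reads it back with `UseContext`, so the same query attends to both text and image tokens. A toy sketch of that store-and-reuse pattern, assuming the fluxion layers behave as they are used in this file:

```python
import torch
import refiners.fluxion.layers as fl

# compute x + 1, stash it in a context, then read it back as the chain's output
toy = fl.Chain(
    fl.Lambda(func=lambda x: x + 1),
    fl.Passthrough(fl.SetContext(context="demo", key="stash")),  # stores its input, forwards it unchanged
    fl.UseContext(context="demo", key="stash"),
)
toy.set_context("demo", {"stash": None})  # make sure the context exists before the first call
assert torch.equal(toy(torch.zeros(2)), torch.ones(2))
```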
@@ -377,7 +344,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         self._image_proj = [image_proj]
 
         self.sub_adapters = [
-            CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
+            CrossAttentionAdapter(target=cross_attn, scale=scale)
             for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
         ]
 
@@ -388,14 +355,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             self.image_proj.load_state_dict(image_proj_state_dict)
 
         for i, cross_attn in enumerate(self.sub_adapters):
-            cross_attn_state_dict: dict[str, Tensor] = {}
+            cross_attention_weights: list[Tensor] = []
             for k, v in weights.items():
                 prefix = f"ip_adapter.{i:03d}."
                 if not k.startswith(prefix):
                     continue
-                cross_attn_state_dict[k.removeprefix(prefix)] = v
+                cross_attention_weights.append(v)
 
-            cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
+            assert len(cross_attention_weights) == 2
+            cross_attn.load_weights(*cross_attention_weights)
 
     @property
     def clip_image_encoder(self) -> CLIPImageEncoderH:
@@ -420,10 +388,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             adapter.eject()
         super().eject()
 
+    @property
+    def scale(self) -> float:
+        return self.sub_adapters[0].scale
+
+    @scale.setter
+    def scale(self, value: float) -> None:
+        for cross_attn in self.sub_adapters:
+            cross_attn.scale = value
+
     def set_scale(self, scale: float) -> None:
         for cross_attn in self.sub_adapters:
             cross_attn.scale = scale
 
+    def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
+        self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
+
-    # These should be concatenated to the CLIP text embedding before setting the UNet context
+    # The resulting embedding should be passed to `set_clip_image_embedding`, not concatenated to the text embedding anymore.
     def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
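End to end, the adapter's public surface is now just two calls plus a scale property. A minimal usage sketch mirroring the README change above (`sd15`, `ip_adapter`, `prompt`, `negative_prompt` and the image path are placeholders for an already-loaded pipeline with the adapter injected):

```python
import torch
from PIL import Image

image_prompt = Image.open("statue.png")  # placeholder path

with torch.no_grad():
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)  # stored in the "ip_adapter" context
    ip_adapter.scale = 0.5  # property form; ip_adapter.set_scale(0.5) is kept as an equivalent
```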
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 9cf70e8..4850e79 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -1168,16 +1168,7 @@ def test_diffusion_ip_adapter(
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
-
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
 
     sd15.set_num_inference_steps(n_steps)
@@ -1220,16 +1211,8 @@ def test_diffusion_sdxl_ip_adapter(
         text=prompt, negative_text=negative_prompt
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
 
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     time_ids = sdxl.default_time_ids
 
     sdxl.set_num_inference_steps(n_steps)
@@ -1285,16 +1268,7 @@ def test_diffusion_ip_adapter_controlnet(
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
-
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
 
     depth_cn_condition = image_to_tensor(
         depth_condition_image.convert("RGB"),
@@ -1343,16 +1317,7 @@ def test_diffusion_ip_adapter_plus(
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
-
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
 
     sd15.set_num_inference_steps(n_steps)
@@ -1396,16 +1361,8 @@ def test_diffusion_sdxl_ip_adapter_plus(
         text=prompt, negative_text=negative_prompt
    )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
 
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     time_ids = sdxl.default_time_ids
 
     sdxl.set_num_inference_steps(n_steps)
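All five test hunks delete the same boilerplate: the manual classifier-free-guidance layout, in which both embeddings stack the negative and conditional batches along dim 0 and the image tokens were appended to the text tokens within each half. A shape-level sketch of what the removed block did (SD 1.5-style dimensions, for illustration):

```python
import torch

clip_text_embedding = torch.randn(2, 77, 768)  # [negative, conditional] text tokens
clip_image_embedding = torch.randn(2, 4, 768)  # [negative, conditional] image tokens

negative_text, conditional_text = clip_text_embedding.chunk(2)
negative_image, conditional_image = clip_image_embedding.chunk(2)

combined = torch.cat(
    (
        torch.cat([negative_text, negative_image], dim=1),
        torch.cat([conditional_text, conditional_image], dim=1),
    )
)
assert combined.shape == (2, 81, 768)  # made obsolete by set_clip_image_embedding
```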
diff --git a/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png b/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png
index e838df1..41a90cc 100644
Binary files a/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png and b/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png differ
diff --git a/tests/foundationals/latent_diffusion/test_image_prompt.py b/tests/foundationals/latent_diffusion/test_image_prompt.py
deleted file mode 100644
index 612de2e..0000000
--- a/tests/foundationals/latent_diffusion/test_image_prompt.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import refiners.fluxion.layers as fl
-from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
-
-
-def test_cross_attention_adapter() -> None:
-    base = fl.Chain(fl.Attention(embedding_dim=4))
-    adapter = CrossAttentionAdapter(base.Attention).inject()
-
-    assert list(base) == [adapter]
-    assert len(list(adapter.layers(fl.Linear))) == 6
-    assert len(list(base.layers(fl.Linear))) == 6
-
-    injection_points = list(adapter.layers(InjectionPoint))
-    assert len(injection_points) == 4
-    for ip in injection_points:
-        assert len(ip) == 1
-        assert isinstance(ip[0], fl.Linear)
-
-    adapter.eject()
-
-    assert len(base) == 1
-    assert isinstance(base[0], fl.Attention)
-    assert len(list(adapter.layers(fl.Linear))) == 2
-    assert len(list(base.layers(fl.Linear))) == 4
-
-    injection_points = list(adapter.layers(InjectionPoint))
-    assert len(injection_points) == 4
-    for ip in injection_points:
-        assert len(ip) == 0
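The structural test above is deleted without a successor. If unit-level coverage is still wanted, a replacement could look like the sketch below, assuming, as the removed test did, that the adapter can wrap a bare `fl.Attention`:

```python
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, ImageCrossAttention


def test_cross_attention_adapter() -> None:
    base = fl.Chain(fl.Attention(embedding_dim=4))
    adapter = CrossAttentionAdapter(base.Attention).inject()

    assert list(base) == [adapter]
    # the image branch adds Wk' and Wv' on top of the target's Wq, Wk, Wv and projection
    assert len(list(adapter.layers(fl.Linear))) == 6
    assert isinstance(adapter.image_cross_attention, ImageCrossAttention)

    adapter.eject()
    assert len(base) == 1
    assert isinstance(base[0], fl.Attention)
    assert len(list(base.layers(fl.Linear))) == 4
```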