refactor CrossAttentionAdapter to work with context.

Author: limiteinductive, 2024-01-08 14:39:55 +01:00 (committed by Benjamin Trom)
parent a08e04c5af
commit c9e973ba41
6 changed files with 92 additions and 203 deletions
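
In short, the refactor moves the image-specific key/value path into a dedicated ImageCrossAttention chain that reads the CLIP image embedding from the "ip_adapter" context, instead of slicing a pre-concatenated text+image embedding inside the adapter. A minimal sketch of the resulting call pattern, assuming an `ip_adapter` that has already been built and injected as in the README example below (`image_prompt` comes from that example; this is an illustration, not a verbatim excerpt from the repository):

    # Compute the image embedding once, then expose it through the "ip_adapter"
    # context; the new ImageCrossAttention layers read it from there at runtime.
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)

    # The new scale property fans out to every per-layer CrossAttentionAdapter.
    ip_adapter.scale = 1.0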


@@ -128,16 +128,7 @@ with no_grad():
         text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
-    ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     time_ids = sdxl.default_time_ids
     condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)


@@ -133,24 +133,14 @@ def main() -> None:
     ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
     assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
-    for i, cross_attn in enumerate(ip_adapter.sub_adapters):
+    for i, _ in enumerate(ip_adapter.sub_adapters):
         cross_attn_index = cross_attn_mapping[i]
         k_ip = f"{cross_attn_index}.to_k_ip.weight"
         v_ip = f"{cross_attn_index}.to_v_ip.weight"
-        # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
-        names = [k for k, _ in cross_attn.named_parameters()]
-        assert len(names) == 2
-        cross_attn_state_dict: dict[str, Any] = {
-            names[0]: ip_adapter_weights[k_ip],
-            names[1]: ip_adapter_weights[v_ip],
-        }
-        cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
-        for k, v in cross_attn_state_dict.items():
-            state_dict[f"ip_adapter.{i:03d}.{k}"] = v
+        # the name of the key is not checked at runtime, so we keep the original name
+        state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
+        state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]
     if args.half:
         state_dict = {key: value.half() for key, value in state_dict.items()}
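
Note (an inference from this script and the loading code later in the commit, not text from the diff): each cross-attention sub-adapter now contributes exactly two tensors to the converted checkpoint, `ip_adapter.{i:03d}.to_k_ip.weight` followed by `ip_adapter.{i:03d}.to_v_ip.weight`. On the loading side, IPAdapter collects every tensor under the `ip_adapter.{i:03d}.` prefix, asserts it found exactly two, and passes them positionally to `CrossAttentionAdapter.load_weights(key_tensor, value_tensor)`, so the key/value order of the converted entries appears to matter.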


@@ -1,17 +1,15 @@
 import math
-from enum import IntEnum
-from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
+from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
-from refiners.fluxion.adapters.lora import Lora
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
+from refiners.fluxion.layers.chain import Distribute
 from refiners.fluxion.utils import image_to_tensor, normalize
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@@ -236,120 +234,89 @@ class PerceiverResampler(fl.Chain):
         return {"perceiver_resampler": {"x": None}}

-class _CrossAttnIndex(IntEnum):
-    TXT_CROSS_ATTN = 0  # text cross-attention
-    IMG_CROSS_ATTN = 1  # image cross-attention

 class InjectionPoint(fl.Chain):
     pass

+class ImageCrossAttention(fl.Chain):
+    def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
+        self.scale = scale
+        super().__init__(
+            fl.Distribute(
+                fl.UseContext(context="ip_adapter", key="query_projection"),
+                fl.Chain(
+                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
+                    fl.Linear(
+                        in_features=text_cross_attention.key_embedding_dim,
+                        out_features=text_cross_attention.inner_dim,
+                        bias=text_cross_attention.use_bias,
+                        device=text_cross_attention.device,
+                        dtype=text_cross_attention.dtype,
+                    ),
+                ),
+                fl.Chain(
+                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
+                    fl.Linear(
+                        in_features=text_cross_attention.key_embedding_dim,
+                        out_features=text_cross_attention.inner_dim,
+                        bias=text_cross_attention.use_bias,
+                        device=text_cross_attention.device,
+                        dtype=text_cross_attention.dtype,
+                    ),
+                ),
+            ),
+            ScaledDotProductAttention(
+                num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
+            ),
+            fl.Multiply(self.scale),
+        )

+class SetQueryProjection(fl.Passthrough):
+    def __init__(self) -> None:
+        super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection"))

 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
     def __init__(
         self,
         target: fl.Attention,
-        text_sequence_length: int = 77,
-        image_sequence_length: int = 4,
         scale: float = 1.0,
     ) -> None:
-        self.text_sequence_length = text_sequence_length
-        self.image_sequence_length = image_sequence_length
-        self.scale = scale
         with self.setup_adapter(target):
             super().__init__(
-                fl.Distribute(
-                    # Note: the same query is used for image cross-attention as for text cross-attention
-                    InjectionPoint(),  # Wq
-                    fl.Parallel(
-                        fl.Chain(
-                            fl.Slicing(dim=1, end=text_sequence_length),
-                            InjectionPoint(),  # Wk
-                        ),
-                        fl.Chain(
-                            fl.Slicing(dim=1, start=text_sequence_length),
-                            fl.Linear(
-                                in_features=self.target.key_embedding_dim,
-                                out_features=self.target.inner_dim,
-                                bias=self.target.use_bias,
-                                device=target.device,
-                                dtype=target.dtype,
-                            ),  # Wk'
-                        ),
-                    ),
-                    fl.Parallel(
-                        fl.Chain(
-                            fl.Slicing(dim=1, end=text_sequence_length),
-                            InjectionPoint(),  # Wv
-                        ),
-                        fl.Chain(
-                            fl.Slicing(dim=1, start=text_sequence_length),
-                            fl.Linear(
-                                in_features=self.target.key_embedding_dim,
-                                out_features=self.target.inner_dim,
-                                bias=self.target.use_bias,
-                                device=target.device,
-                                dtype=target.dtype,
-                            ),  # Wv'
-                        ),
-                    ),
-                ),
                 fl.Sum(
-                    fl.Chain(
-                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
-                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
-                    ),
-                    fl.Chain(
-                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
-                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
-                        fl.Lambda(func=self.scale_outputs),
-                    ),
+                    target[:-1],  # original text cross attention
+                    ImageCrossAttention(text_cross_attention=target, scale=scale),
                 ),
-                InjectionPoint(),  # proj
+                target[-1],  # projection
             )
+        self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())

-    def select_qkv(
-        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
-    ) -> tuple[Tensor, Tensor, Tensor]:
-        return (query, keys[index.value], values[index.value])
+    @property
+    def image_cross_attention(self) -> ImageCrossAttention:
+        return self.ensure_find(ImageCrossAttention)

-    def scale_outputs(self, x: Tensor) -> Tensor:
-        return x * self.scale
+    @property
+    def image_key_projection(self) -> fl.Linear:
+        return self.image_cross_attention.Distribute[1].Linear

-    def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
-        def f(m: fl.Module, _: fl.Chain) -> bool:
-            if isinstance(m, Lora):  # do not adapt LoRAs
-                raise StopIteration
-            return isinstance(m, k)
-        return f
+    @property
+    def image_value_projection(self) -> fl.Linear:
+        return self.image_cross_attention.Distribute[2].Linear

-    def _target_linears(self) -> list[fl.Linear]:
-        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
+    @property
+    def scale(self) -> float:
+        return self.image_cross_attention.scale

-    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
-        linears = self._target_linears()
-        assert len(linears) == 4  # Wq, Wk, Wv and Proj
-        injection_points = list(self.layers(InjectionPoint))
-        assert len(injection_points) == 4
-        for linear, ip in zip(linears, injection_points):
-            ip.append(linear)
-            assert len(ip) == 1
-        return super().inject(parent)
+    @scale.setter
+    def scale(self, value: float) -> None:
+        self.image_cross_attention.scale = value

-    def eject(self) -> None:
-        injection_points = list(self.layers(InjectionPoint))
-        assert len(injection_points) == 4
-        for ip in injection_points:
-            ip.pop()
-            assert len(ip) == 0
-        super().eject()
+    def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
+        self.image_key_projection.weight = nn.Parameter(key_tensor)
+        self.image_value_projection.weight = nn.Parameter(value_tensor)
+        self.image_cross_attention.to(self.device, self.dtype)

 class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
class IPAdapter(Generic[T], fl.Chain, Adapter[T]): class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@@ -377,7 +344,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         self._image_proj = [image_proj]
         self.sub_adapters = [
-            CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
+            CrossAttentionAdapter(target=cross_attn, scale=scale)
             for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
         ]
@@ -388,14 +355,15 @@
         self.image_proj.load_state_dict(image_proj_state_dict)
         for i, cross_attn in enumerate(self.sub_adapters):
-            cross_attn_state_dict: dict[str, Tensor] = {}
+            cross_attention_weights: list[Tensor] = []
             for k, v in weights.items():
                 prefix = f"ip_adapter.{i:03d}."
                 if not k.startswith(prefix):
                     continue
-                cross_attn_state_dict[k.removeprefix(prefix)] = v
-            cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
+                cross_attention_weights.append(v)
+            assert len(cross_attention_weights) == 2
+            cross_attn.load_weights(*cross_attention_weights)

     @property
     def clip_image_encoder(self) -> CLIPImageEncoderH:
@@ -420,10 +388,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             adapter.eject()
         super().eject()

+    @property
+    def scale(self) -> float:
+        return self.sub_adapters[0].scale
+
+    @scale.setter
+    def scale(self, value: float) -> None:
+        for cross_attn in self.sub_adapters:
+            cross_attn.scale = value
+
     def set_scale(self, scale: float) -> None:
         for cross_attn in self.sub_adapters:
             cross_attn.scale = scale

+    def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
+        self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
+
     # These should be concatenated to the CLIP text embedding before setting the UNet context
     def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
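
Reading aid (my paraphrase of the hunk above, not code from the repository): after injection, the adapted attention is organized roughly as follows.

    # CrossAttentionAdapter now wraps the target fl.Attention as:
    #
    #   fl.Sum(
    #       target[:-1],                          # original Q/K/V projections + text cross-attention
    #       ImageCrossAttention(target, scale),   # new K'/V' projections over the context image embedding
    #   ),
    #   target[-1],                               # original output projection
    #
    # A SetQueryProjection passthrough is inserted right after the wrapped
    # attention's Distribute (its Q/K/V projections); it stores the query
    # projection in the "ip_adapter" context so that ImageCrossAttention can
    # reuse it via UseContext(context="ip_adapter", key="query_projection").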


@@ -1168,16 +1168,7 @@ def test_diffusion_ip_adapter(
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
-    ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     sd15.set_num_inference_steps(n_steps)
@@ -1220,16 +1211,8 @@ def test_diffusion_sdxl_ip_adapter(
         text=prompt, negative_text=negative_prompt
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
-    ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     time_ids = sdxl.default_time_ids
     sdxl.set_num_inference_steps(n_steps)
@@ -1285,16 +1268,7 @@ def test_diffusion_ip_adapter_controlnet(
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
-    ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     depth_cn_condition = image_to_tensor(
         depth_condition_image.convert("RGB"),
@@ -1343,16 +1317,7 @@ def test_diffusion_ip_adapter_plus(
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
-    ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     sd15.set_num_inference_steps(n_steps)
@@ -1396,16 +1361,8 @@ def test_diffusion_sdxl_ip_adapter_plus(
         text=prompt, negative_text=negative_prompt
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
-    ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
     time_ids = sdxl.default_time_ids
     sdxl.set_num_inference_steps(n_steps)

Binary file not shown (reference image; 309 KiB before and after).


@@ -1,29 +0,0 @@
-import refiners.fluxion.layers as fl
-from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
-
-
-def test_cross_attention_adapter() -> None:
-    base = fl.Chain(fl.Attention(embedding_dim=4))
-    adapter = CrossAttentionAdapter(base.Attention).inject()
-
-    assert list(base) == [adapter]
-    assert len(list(adapter.layers(fl.Linear))) == 6
-    assert len(list(base.layers(fl.Linear))) == 6
-
-    injection_points = list(adapter.layers(InjectionPoint))
-    assert len(injection_points) == 4
-    for ip in injection_points:
-        assert len(ip) == 1
-        assert isinstance(ip[0], fl.Linear)
-
-    adapter.eject()
-
-    assert len(base) == 1
-    assert isinstance(base[0], fl.Attention)
-    assert len(list(adapter.layers(fl.Linear))) == 2
-    assert len(list(base.layers(fl.Linear))) == 4
-
-    injection_points = list(adapter.layers(InjectionPoint))
-    assert len(injection_points) == 4
-    for ip in injection_points:
-        assert len(ip) == 0