refactor CrossAttentionAdapter to work with context.

This commit is contained in:
limiteinductive 2024-01-08 14:39:55 +01:00 committed by Benjamin Trom
parent a08e04c5af
commit c9e973ba41
6 changed files with 92 additions and 203 deletions

View file

@ -128,16 +128,7 @@ with no_grad():
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
time_ids = sdxl.default_time_ids
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)

View file

@ -133,24 +133,14 @@ def main() -> None:
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
for i, cross_attn in enumerate(ip_adapter.sub_adapters):
for i, _ in enumerate(ip_adapter.sub_adapters):
cross_attn_index = cross_attn_mapping[i]
k_ip = f"{cross_attn_index}.to_k_ip.weight"
v_ip = f"{cross_attn_index}.to_v_ip.weight"
# Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
names = [k for k, _ in cross_attn.named_parameters()]
assert len(names) == 2
cross_attn_state_dict: dict[str, Any] = {
names[0]: ip_adapter_weights[k_ip],
names[1]: ip_adapter_weights[v_ip],
}
cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
for k, v in cross_attn_state_dict.items():
state_dict[f"ip_adapter.{i:03d}.{k}"] = v
# the name of the key is not checked at runtime, so we keep the original name
state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]
if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}

View file

@ -1,17 +1,15 @@
import math
from enum import IntEnum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Lora
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import Distribute
from refiners.fluxion.utils import image_to_tensor, normalize
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@ -236,120 +234,89 @@ class PerceiverResampler(fl.Chain):
return {"perceiver_resampler": {"x": None}}
class _CrossAttnIndex(IntEnum):
TXT_CROSS_ATTN = 0 # text cross-attention
IMG_CROSS_ATTN = 1 # image cross-attention
class InjectionPoint(fl.Chain):
pass
class ImageCrossAttention(fl.Chain):
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
self.scale = scale
super().__init__(
fl.Distribute(
fl.UseContext(context="ip_adapter", key="query_projection"),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.key_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.key_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
),
ScaledDotProductAttention(
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
),
fl.Multiply(self.scale),
)
class SetQueryProjection(fl.Passthrough):
def __init__(self) -> None:
super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection"))
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def __init__(
self,
target: fl.Attention,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0,
) -> None:
self.text_sequence_length = text_sequence_length
self.image_sequence_length = image_sequence_length
self.scale = scale
with self.setup_adapter(target):
super().__init__(
fl.Distribute(
# Note: the same query is used for image cross-attention as for text cross-attention
InjectionPoint(), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wk'
),
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wv'
),
),
),
fl.Sum(
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
),
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
fl.Lambda(func=self.scale_outputs),
),
target[:-1], # original text cross attention
ImageCrossAttention(text_cross_attention=target, scale=scale),
),
InjectionPoint(), # proj
target[-1], # projection
)
self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())
def select_qkv(
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
) -> tuple[Tensor, Tensor, Tensor]:
return (query, keys[index.value], values[index.value])
@property
def image_cross_attention(self) -> ImageCrossAttention:
return self.ensure_find(ImageCrossAttention)
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
@property
def image_key_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[1].Linear
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt LoRAs
raise StopIteration
return isinstance(m, k)
@property
def image_value_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[2].Linear
return f
@property
def scale(self) -> float:
return self.image_cross_attention.scale
def _target_linears(self) -> list[fl.Linear]:
return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
@scale.setter
def scale(self, value: float) -> None:
self.image_cross_attention.scale = value
def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
linears = self._target_linears()
assert len(linears) == 4 # Wq, Wk, Wv and Proj
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for linear, ip in zip(linears, injection_points):
ip.append(linear)
assert len(ip) == 1
return super().inject(parent)
def eject(self) -> None:
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
ip.pop()
assert len(ip) == 0
super().eject()
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
self.image_key_projection.weight = nn.Parameter(key_tensor)
self.image_value_projection.weight = nn.Parameter(value_tensor)
self.image_cross_attention.to(self.device, self.dtype)
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@ -377,7 +344,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self._image_proj = [image_proj]
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
CrossAttentionAdapter(target=cross_attn, scale=scale)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]
@ -388,14 +355,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self.image_proj.load_state_dict(image_proj_state_dict)
for i, cross_attn in enumerate(self.sub_adapters):
cross_attn_state_dict: dict[str, Tensor] = {}
cross_attention_weights: list[Tensor] = []
for k, v in weights.items():
prefix = f"ip_adapter.{i:03d}."
if not k.startswith(prefix):
continue
cross_attn_state_dict[k.removeprefix(prefix)] = v
cross_attention_weights.append(v)
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
assert len(cross_attention_weights) == 2
cross_attn.load_weights(*cross_attention_weights)
@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
@ -420,10 +388,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
adapter.eject()
super().eject()
@property
def scale(self) -> float:
return self.sub_adapters[0].scale
@scale.setter
def scale(self, value: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = value
def set_scale(self, scale: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = scale
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
# These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder

View file

@ -1168,16 +1168,7 @@ def test_diffusion_ip_adapter(
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
sd15.set_num_inference_steps(n_steps)
@ -1220,16 +1211,8 @@ def test_diffusion_sdxl_ip_adapter(
text=prompt, negative_text=negative_prompt
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
@ -1285,16 +1268,7 @@ def test_diffusion_ip_adapter_controlnet(
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
depth_cn_condition = image_to_tensor(
depth_condition_image.convert("RGB"),
@ -1343,16 +1317,7 @@ def test_diffusion_ip_adapter_plus(
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
sd15.set_num_inference_steps(n_steps)
@ -1396,16 +1361,8 @@ def test_diffusion_sdxl_ip_adapter_plus(
text=prompt, negative_text=negative_prompt
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 309 KiB

After

Width:  |  Height:  |  Size: 309 KiB

View file

@ -1,29 +0,0 @@
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
def test_cross_attention_adapter() -> None:
base = fl.Chain(fl.Attention(embedding_dim=4))
adapter = CrossAttentionAdapter(base.Attention).inject()
assert list(base) == [adapter]
assert len(list(adapter.layers(fl.Linear))) == 6
assert len(list(base.layers(fl.Linear))) == 6
injection_points = list(adapter.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
assert len(ip) == 1
assert isinstance(ip[0], fl.Linear)
adapter.eject()
assert len(base) == 1
assert isinstance(base[0], fl.Attention)
assert len(list(adapter.layers(fl.Linear))) == 2
assert len(list(base.layers(fl.Linear))) == 4
injection_points = list(adapter.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
assert len(ip) == 0