mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
refactor CrossAttentionAdapter to work with context.
This commit is contained in:
parent
a08e04c5af
commit
c9e973ba41
11
README.md
11
README.md
|
@ -128,16 +128,7 @@ with no_grad():
|
|||
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
|
||||
|
|
|
@ -133,24 +133,14 @@ def main() -> None:
|
|||
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
|
||||
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
|
||||
|
||||
for i, cross_attn in enumerate(ip_adapter.sub_adapters):
|
||||
for i, _ in enumerate(ip_adapter.sub_adapters):
|
||||
cross_attn_index = cross_attn_mapping[i]
|
||||
k_ip = f"{cross_attn_index}.to_k_ip.weight"
|
||||
v_ip = f"{cross_attn_index}.to_v_ip.weight"
|
||||
|
||||
# Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
|
||||
|
||||
names = [k for k, _ in cross_attn.named_parameters()]
|
||||
assert len(names) == 2
|
||||
|
||||
cross_attn_state_dict: dict[str, Any] = {
|
||||
names[0]: ip_adapter_weights[k_ip],
|
||||
names[1]: ip_adapter_weights[v_ip],
|
||||
}
|
||||
cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
|
||||
|
||||
for k, v in cross_attn_state_dict.items():
|
||||
state_dict[f"ip_adapter.{i:03d}.{k}"] = v
|
||||
# the name of the key is not checked at runtime, so we keep the original name
|
||||
state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
|
||||
state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]
|
||||
|
||||
if args.half:
|
||||
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
import math
|
||||
from enum import IntEnum
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from jaxtyping import Float
|
||||
from PIL import Image
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import Lora
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||
from refiners.fluxion.layers.chain import Distribute
|
||||
from refiners.fluxion.utils import image_to_tensor, normalize
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||
|
||||
|
@ -236,120 +234,89 @@ class PerceiverResampler(fl.Chain):
|
|||
return {"perceiver_resampler": {"x": None}}
|
||||
|
||||
|
||||
class _CrossAttnIndex(IntEnum):
|
||||
TXT_CROSS_ATTN = 0 # text cross-attention
|
||||
IMG_CROSS_ATTN = 1 # image cross-attention
|
||||
|
||||
|
||||
class InjectionPoint(fl.Chain):
|
||||
pass
|
||||
|
||||
|
||||
class ImageCrossAttention(fl.Chain):
|
||||
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
|
||||
self.scale = scale
|
||||
super().__init__(
|
||||
fl.Distribute(
|
||||
fl.UseContext(context="ip_adapter", key="query_projection"),
|
||||
fl.Chain(
|
||||
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
|
||||
fl.Linear(
|
||||
in_features=text_cross_attention.key_embedding_dim,
|
||||
out_features=text_cross_attention.inner_dim,
|
||||
bias=text_cross_attention.use_bias,
|
||||
device=text_cross_attention.device,
|
||||
dtype=text_cross_attention.dtype,
|
||||
),
|
||||
),
|
||||
fl.Chain(
|
||||
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
|
||||
fl.Linear(
|
||||
in_features=text_cross_attention.key_embedding_dim,
|
||||
out_features=text_cross_attention.inner_dim,
|
||||
bias=text_cross_attention.use_bias,
|
||||
device=text_cross_attention.device,
|
||||
dtype=text_cross_attention.dtype,
|
||||
),
|
||||
),
|
||||
),
|
||||
ScaledDotProductAttention(
|
||||
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
|
||||
),
|
||||
fl.Multiply(self.scale),
|
||||
)
|
||||
|
||||
|
||||
class SetQueryProjection(fl.Passthrough):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection"))
|
||||
|
||||
|
||||
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||
def __init__(
|
||||
self,
|
||||
target: fl.Attention,
|
||||
text_sequence_length: int = 77,
|
||||
image_sequence_length: int = 4,
|
||||
scale: float = 1.0,
|
||||
) -> None:
|
||||
self.text_sequence_length = text_sequence_length
|
||||
self.image_sequence_length = image_sequence_length
|
||||
self.scale = scale
|
||||
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(
|
||||
fl.Distribute(
|
||||
# Note: the same query is used for image cross-attention as for text cross-attention
|
||||
InjectionPoint(), # Wq
|
||||
fl.Parallel(
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, end=text_sequence_length),
|
||||
InjectionPoint(), # Wk
|
||||
),
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=text_sequence_length),
|
||||
fl.Linear(
|
||||
in_features=self.target.key_embedding_dim,
|
||||
out_features=self.target.inner_dim,
|
||||
bias=self.target.use_bias,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
), # Wk'
|
||||
),
|
||||
),
|
||||
fl.Parallel(
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, end=text_sequence_length),
|
||||
InjectionPoint(), # Wv
|
||||
),
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=text_sequence_length),
|
||||
fl.Linear(
|
||||
in_features=self.target.key_embedding_dim,
|
||||
out_features=self.target.inner_dim,
|
||||
bias=self.target.use_bias,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
), # Wv'
|
||||
),
|
||||
),
|
||||
),
|
||||
fl.Sum(
|
||||
fl.Chain(
|
||||
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
|
||||
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
|
||||
),
|
||||
fl.Chain(
|
||||
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
|
||||
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
|
||||
fl.Lambda(func=self.scale_outputs),
|
||||
),
|
||||
target[:-1], # original text cross attention
|
||||
ImageCrossAttention(text_cross_attention=target, scale=scale),
|
||||
),
|
||||
InjectionPoint(), # proj
|
||||
target[-1], # projection
|
||||
)
|
||||
self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())
|
||||
|
||||
def select_qkv(
|
||||
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
return (query, keys[index.value], values[index.value])
|
||||
@property
|
||||
def image_cross_attention(self) -> ImageCrossAttention:
|
||||
return self.ensure_find(ImageCrossAttention)
|
||||
|
||||
def scale_outputs(self, x: Tensor) -> Tensor:
|
||||
return x * self.scale
|
||||
@property
|
||||
def image_key_projection(self) -> fl.Linear:
|
||||
return self.image_cross_attention.Distribute[1].Linear
|
||||
|
||||
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||
def f(m: fl.Module, _: fl.Chain) -> bool:
|
||||
if isinstance(m, Lora): # do not adapt LoRAs
|
||||
raise StopIteration
|
||||
return isinstance(m, k)
|
||||
@property
|
||||
def image_value_projection(self) -> fl.Linear:
|
||||
return self.image_cross_attention.Distribute[2].Linear
|
||||
|
||||
return f
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self.image_cross_attention.scale
|
||||
|
||||
def _target_linears(self) -> list[fl.Linear]:
|
||||
return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
self.image_cross_attention.scale = value
|
||||
|
||||
def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
|
||||
linears = self._target_linears()
|
||||
assert len(linears) == 4 # Wq, Wk, Wv and Proj
|
||||
|
||||
injection_points = list(self.layers(InjectionPoint))
|
||||
assert len(injection_points) == 4
|
||||
|
||||
for linear, ip in zip(linears, injection_points):
|
||||
ip.append(linear)
|
||||
assert len(ip) == 1
|
||||
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
injection_points = list(self.layers(InjectionPoint))
|
||||
assert len(injection_points) == 4
|
||||
|
||||
for ip in injection_points:
|
||||
ip.pop()
|
||||
assert len(ip) == 0
|
||||
|
||||
super().eject()
|
||||
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
|
||||
self.image_key_projection.weight = nn.Parameter(key_tensor)
|
||||
self.image_value_projection.weight = nn.Parameter(value_tensor)
|
||||
self.image_cross_attention.to(self.device, self.dtype)
|
||||
|
||||
|
||||
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
|
@ -377,7 +344,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
self._image_proj = [image_proj]
|
||||
|
||||
self.sub_adapters = [
|
||||
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
|
||||
CrossAttentionAdapter(target=cross_attn, scale=scale)
|
||||
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
|
||||
]
|
||||
|
||||
|
@ -388,14 +355,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
self.image_proj.load_state_dict(image_proj_state_dict)
|
||||
|
||||
for i, cross_attn in enumerate(self.sub_adapters):
|
||||
cross_attn_state_dict: dict[str, Tensor] = {}
|
||||
cross_attention_weights: list[Tensor] = []
|
||||
for k, v in weights.items():
|
||||
prefix = f"ip_adapter.{i:03d}."
|
||||
if not k.startswith(prefix):
|
||||
continue
|
||||
cross_attn_state_dict[k.removeprefix(prefix)] = v
|
||||
cross_attention_weights.append(v)
|
||||
|
||||
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
|
||||
assert len(cross_attention_weights) == 2
|
||||
cross_attn.load_weights(*cross_attention_weights)
|
||||
|
||||
@property
|
||||
def clip_image_encoder(self) -> CLIPImageEncoderH:
|
||||
|
@ -420,10 +388,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
adapter.eject()
|
||||
super().eject()
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self.sub_adapters[0].scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
for cross_attn in self.sub_adapters:
|
||||
cross_attn.scale = value
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
for cross_attn in self.sub_adapters:
|
||||
cross_attn.scale = scale
|
||||
|
||||
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
||||
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
||||
|
||||
# These should be concatenated to the CLIP text embedding before setting the UNet context
|
||||
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
|
||||
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
||||
|
|
|
@ -1168,16 +1168,7 @@ def test_diffusion_ip_adapter(
|
|||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
|
@ -1220,16 +1211,8 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
text=prompt, negative_text=negative_prompt
|
||||
)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
|
||||
|
@ -1285,16 +1268,7 @@ def test_diffusion_ip_adapter_controlnet(
|
|||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
depth_cn_condition = image_to_tensor(
|
||||
depth_condition_image.convert("RGB"),
|
||||
|
@ -1343,16 +1317,7 @@ def test_diffusion_ip_adapter_plus(
|
|||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
|
@ -1396,16 +1361,8 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
|||
text=prompt, negative_text=negative_prompt
|
||||
)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 309 KiB After Width: | Height: | Size: 309 KiB |
|
@ -1,29 +0,0 @@
|
|||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
|
||||
|
||||
|
||||
def test_cross_attention_adapter() -> None:
|
||||
base = fl.Chain(fl.Attention(embedding_dim=4))
|
||||
adapter = CrossAttentionAdapter(base.Attention).inject()
|
||||
|
||||
assert list(base) == [adapter]
|
||||
assert len(list(adapter.layers(fl.Linear))) == 6
|
||||
assert len(list(base.layers(fl.Linear))) == 6
|
||||
|
||||
injection_points = list(adapter.layers(InjectionPoint))
|
||||
assert len(injection_points) == 4
|
||||
for ip in injection_points:
|
||||
assert len(ip) == 1
|
||||
assert isinstance(ip[0], fl.Linear)
|
||||
|
||||
adapter.eject()
|
||||
|
||||
assert len(base) == 1
|
||||
assert isinstance(base[0], fl.Attention)
|
||||
assert len(list(adapter.layers(fl.Linear))) == 2
|
||||
assert len(list(base.layers(fl.Linear))) == 4
|
||||
|
||||
injection_points = list(adapter.layers(InjectionPoint))
|
||||
assert len(injection_points) == 4
|
||||
for ip in injection_points:
|
||||
assert len(ip) == 0
|
Loading…
Reference in a new issue