mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
refactor CrossAttentionAdapter to work with context.
This commit is contained in:
parent a08e04c5af
commit c9e973ba41
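In short: the adapter still computes IP-Adapter's decoupled cross-attention, where the text branch and the image branch share the same query projection and the image branch's output is scaled and summed onto the text branch's output. What this commit changes is the plumbing: instead of splicing extra Linear layers into the target attention through InjectionPoint wrappers and sequence slicing, the image branch becomes a self-contained ImageCrossAttention chain that reads the shared query projection and the CLIP image embedding from an "ip_adapter" context (see the CrossAttentionAdapter diff below). A plain-PyTorch sketch of that computation, with illustrative sizes only, not library code:

import torch
import torch.nn.functional as F

# Decoupled cross-attention as wired up by CrossAttentionAdapter (single head,
# projections already applied; batch/sequence/embedding sizes are illustrative).
batch, q_len, txt_len, img_len, dim = 1, 64, 77, 4, 320

q = torch.randn(batch, q_len, dim)        # shared query projection (stored in the context)
k_txt = torch.randn(batch, txt_len, dim)  # original Wk output
v_txt = torch.randn(batch, txt_len, dim)  # original Wv output
k_img = torch.randn(batch, img_len, dim)  # new to_k_ip projection of the image embedding
v_img = torch.randn(batch, img_len, dim)  # new to_v_ip projection of the image embedding
scale = 1.0

text_attn = F.scaled_dot_product_attention(q, k_txt, v_txt)
image_attn = F.scaled_dot_product_attention(q, k_img, v_img)
output = text_attn + scale * image_attn   # i.e. fl.Sum(target[:-1] ..., ImageCrossAttention(...))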
README.md (11 changed lines)
@@ -128,16 +128,7 @@ with no_grad():
         text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
-
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
     time_ids = sdxl.default_time_ids

     condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
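After the patch, the README example no longer splits and re-concatenates the text and image embeddings by hand; the image embedding is simply handed to the adapter. A short sketch assembled from the diff above (setup of `sdxl`, `ip_adapter` and `image_prompt` is omitted and unchanged by this commit):

# Sketch of the new README usage (model and adapter setup omitted).
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
# No manual chunk()/cat() of text and image embeddings: the adapter reads the
# embedding from its "ip_adapter" context during denoising.
time_ids = sdxl.default_time_ids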
@@ -133,24 +133,14 @@ def main() -> None:
     ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
     assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2

-    for i, cross_attn in enumerate(ip_adapter.sub_adapters):
+    for i, _ in enumerate(ip_adapter.sub_adapters):
         cross_attn_index = cross_attn_mapping[i]
         k_ip = f"{cross_attn_index}.to_k_ip.weight"
         v_ip = f"{cross_attn_index}.to_v_ip.weight"

-        # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
-        names = [k for k, _ in cross_attn.named_parameters()]
-        assert len(names) == 2
-
-        cross_attn_state_dict: dict[str, Any] = {
-            names[0]: ip_adapter_weights[k_ip],
-            names[1]: ip_adapter_weights[v_ip],
-        }
-        cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
-
-        for k, v in cross_attn_state_dict.items():
-            state_dict[f"ip_adapter.{i:03d}.{k}"] = v
-
+        # the name of the key is not checked at runtime, so we keep the original name
+        state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
+        state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]

     if args.half:
         state_dict = {key: value.half() for key, value in state_dict.items()}
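The converted checkpoint therefore stores exactly two tensors per cross-attention sub-adapter, under stable index-based keys rather than keys derived from parameter names. A standalone illustration of the resulting layout (the tensor shapes and the mapping value are made up for the example; the real script reads them from the diffusers IP-Adapter weights):

import torch

# Illustration of the converted checkpoint layout produced by the loop above.
ip_adapter_weights = {
    "1.to_k_ip.weight": torch.randn(320, 768),
    "1.to_v_ip.weight": torch.randn(320, 768),
}
cross_attn_mapping = [1]

state_dict: dict[str, torch.Tensor] = {}
for i, cross_attn_index in enumerate(cross_attn_mapping):
    state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[f"{cross_attn_index}.to_k_ip.weight"]
    state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[f"{cross_attn_index}.to_v_ip.weight"]

assert sorted(state_dict) == ["ip_adapter.000.to_k_ip.weight", "ip_adapter.000.to_v_ip.weight"]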
@@ -1,17 +1,15 @@
 import math
-from enum import IntEnum
-from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar

 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
+from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like

 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
-from refiners.fluxion.adapters.lora import Lora
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
+from refiners.fluxion.layers.chain import Distribute
 from refiners.fluxion.utils import image_to_tensor, normalize
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH

@@ -236,120 +234,89 @@ class PerceiverResampler(fl.Chain):
         return {"perceiver_resampler": {"x": None}}


-class _CrossAttnIndex(IntEnum):
-    TXT_CROSS_ATTN = 0  # text cross-attention
-    IMG_CROSS_ATTN = 1  # image cross-attention
-
-
 class InjectionPoint(fl.Chain):
     pass


+class ImageCrossAttention(fl.Chain):
+    def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
+        self.scale = scale
+        super().__init__(
+            fl.Distribute(
+                fl.UseContext(context="ip_adapter", key="query_projection"),
+                fl.Chain(
+                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
+                    fl.Linear(
+                        in_features=text_cross_attention.key_embedding_dim,
+                        out_features=text_cross_attention.inner_dim,
+                        bias=text_cross_attention.use_bias,
+                        device=text_cross_attention.device,
+                        dtype=text_cross_attention.dtype,
+                    ),
+                ),
+                fl.Chain(
+                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
+                    fl.Linear(
+                        in_features=text_cross_attention.key_embedding_dim,
+                        out_features=text_cross_attention.inner_dim,
+                        bias=text_cross_attention.use_bias,
+                        device=text_cross_attention.device,
+                        dtype=text_cross_attention.dtype,
+                    ),
+                ),
+            ),
+            ScaledDotProductAttention(
+                num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
+            ),
+            fl.Multiply(self.scale),
+        )
+
+
+class SetQueryProjection(fl.Passthrough):
+    def __init__(self) -> None:
+        super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection"))
+
+
 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
     def __init__(
         self,
         target: fl.Attention,
-        text_sequence_length: int = 77,
-        image_sequence_length: int = 4,
         scale: float = 1.0,
     ) -> None:
-        self.text_sequence_length = text_sequence_length
-        self.image_sequence_length = image_sequence_length
-        self.scale = scale
-
         with self.setup_adapter(target):
             super().__init__(
-                fl.Distribute(
-                    # Note: the same query is used for image cross-attention as for text cross-attention
-                    InjectionPoint(),  # Wq
-                    fl.Parallel(
-                        fl.Chain(
-                            fl.Slicing(dim=1, end=text_sequence_length),
-                            InjectionPoint(),  # Wk
-                        ),
-                        fl.Chain(
-                            fl.Slicing(dim=1, start=text_sequence_length),
-                            fl.Linear(
-                                in_features=self.target.key_embedding_dim,
-                                out_features=self.target.inner_dim,
-                                bias=self.target.use_bias,
-                                device=target.device,
-                                dtype=target.dtype,
-                            ),  # Wk'
-                        ),
-                    ),
-                    fl.Parallel(
-                        fl.Chain(
-                            fl.Slicing(dim=1, end=text_sequence_length),
-                            InjectionPoint(),  # Wv
-                        ),
-                        fl.Chain(
-                            fl.Slicing(dim=1, start=text_sequence_length),
-                            fl.Linear(
-                                in_features=self.target.key_embedding_dim,
-                                out_features=self.target.inner_dim,
-                                bias=self.target.use_bias,
-                                device=target.device,
-                                dtype=target.dtype,
-                            ),  # Wv'
-                        ),
-                    ),
-                ),
                 fl.Sum(
-                    fl.Chain(
-                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
-                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
-                    ),
-                    fl.Chain(
-                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
-                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
-                        fl.Lambda(func=self.scale_outputs),
-                    ),
-                ),
-                InjectionPoint(),  # proj
+                    target[:-1],  # original text cross attention
+                    ImageCrossAttention(text_cross_attention=target, scale=scale),
+                ),
+                target[-1],  # projection
             )
+        self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())

-    def select_qkv(
-        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
-    ) -> tuple[Tensor, Tensor, Tensor]:
-        return (query, keys[index.value], values[index.value])
+    @property
+    def image_cross_attention(self) -> ImageCrossAttention:
+        return self.ensure_find(ImageCrossAttention)

-    def scale_outputs(self, x: Tensor) -> Tensor:
-        return x * self.scale
+    @property
+    def image_key_projection(self) -> fl.Linear:
+        return self.image_cross_attention.Distribute[1].Linear

-    def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
-        def f(m: fl.Module, _: fl.Chain) -> bool:
-            if isinstance(m, Lora):  # do not adapt LoRAs
-                raise StopIteration
-            return isinstance(m, k)
-
-        return f
+    @property
+    def image_value_projection(self) -> fl.Linear:
+        return self.image_cross_attention.Distribute[2].Linear
+
+    @property
+    def scale(self) -> float:
+        return self.image_cross_attention.scale

-    def _target_linears(self) -> list[fl.Linear]:
-        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
+    @scale.setter
+    def scale(self, value: float) -> None:
+        self.image_cross_attention.scale = value

-    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
-        linears = self._target_linears()
-        assert len(linears) == 4  # Wq, Wk, Wv and Proj
-
-        injection_points = list(self.layers(InjectionPoint))
-        assert len(injection_points) == 4
-
-        for linear, ip in zip(linears, injection_points):
-            ip.append(linear)
-            assert len(ip) == 1
-
-        return super().inject(parent)
-
-    def eject(self) -> None:
-        injection_points = list(self.layers(InjectionPoint))
-        assert len(injection_points) == 4
-
-        for ip in injection_points:
-            ip.pop()
-            assert len(ip) == 0
-
-        super().eject()
+    def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
+        self.image_key_projection.weight = nn.Parameter(key_tensor)
+        self.image_value_projection.weight = nn.Parameter(value_tensor)
+        self.image_cross_attention.to(self.device, self.dtype)


 class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
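The new CrossAttentionAdapter.load_weights above is essentially a direct parameter assignment onto the two image projections. The same operation in isolation, with plain torch (layer sizes are illustrative; in the library the projections live inside ImageCrossAttention's Distribute):

import torch
from torch import nn

# Standalone illustration of what load_weights does: the two converted
# IP-Adapter tensors become the weights of the image key/value Linears.
image_key_projection = nn.Linear(768, 320, bias=False)
image_value_projection = nn.Linear(768, 320, bias=False)

key_tensor = torch.randn(320, 768)    # to_k_ip.weight
value_tensor = torch.randn(320, 768)  # to_v_ip.weight

image_key_projection.weight = nn.Parameter(key_tensor)
image_value_projection.weight = nn.Parameter(value_tensor)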
@@ -377,7 +344,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         self._image_proj = [image_proj]

         self.sub_adapters = [
-            CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
+            CrossAttentionAdapter(target=cross_attn, scale=scale)
             for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
         ]

@@ -388,14 +355,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         self.image_proj.load_state_dict(image_proj_state_dict)

         for i, cross_attn in enumerate(self.sub_adapters):
-            cross_attn_state_dict: dict[str, Tensor] = {}
+            cross_attention_weights: list[Tensor] = []
             for k, v in weights.items():
                 prefix = f"ip_adapter.{i:03d}."
                 if not k.startswith(prefix):
                     continue
-                cross_attn_state_dict[k.removeprefix(prefix)] = v
+                cross_attention_weights.append(v)

-            cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
+            assert len(cross_attention_weights) == 2
+            cross_attn.load_weights(*cross_attention_weights)

     @property
     def clip_image_encoder(self) -> CLIPImageEncoderH:
@@ -420,10 +388,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             adapter.eject()
         super().eject()

+    @property
+    def scale(self) -> float:
+        return self.sub_adapters[0].scale
+
+    @scale.setter
+    def scale(self, value: float) -> None:
+        for cross_attn in self.sub_adapters:
+            cross_attn.scale = value
+
     def set_scale(self, scale: float) -> None:
         for cross_attn in self.sub_adapters:
             cross_attn.scale = scale

+    def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
+        self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
+
     # These should be concatenated to the CLIP text embedding before setting the UNet context
     def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
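IPAdapter now exposes a scale property that fans out to every sub-adapter, plus set_clip_image_embedding, which writes the embedding into the "ip_adapter" context instead of requiring it to be concatenated onto the text embedding. A hypothetical usage sketch (assumes an already-injected `ip_adapter` and an `image_prompt` image; not runnable standalone):

# Hypothetical usage of the API added above.
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
ip_adapter.set_clip_image_embedding(clip_image_embedding)  # stored in the "ip_adapter" context
ip_adapter.scale = 0.5  # forwarded to every CrossAttentionAdapter / ImageCrossAttention
assert ip_adapter.scale == 0.5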
@@ -1168,16 +1168,7 @@ def test_diffusion_ip_adapter(

     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
-
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)

     sd15.set_num_inference_steps(n_steps)

@@ -1220,16 +1211,8 @@ def test_diffusion_sdxl_ip_adapter(
         text=prompt, negative_text=negative_prompt
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)

     time_ids = sdxl.default_time_ids
     sdxl.set_num_inference_steps(n_steps)

@@ -1285,16 +1268,7 @@ def test_diffusion_ip_adapter_controlnet(

     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
-
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)

     depth_cn_condition = image_to_tensor(
         depth_condition_image.convert("RGB"),
@@ -1343,16 +1317,7 @@ def test_diffusion_ip_adapter_plus(

     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
-
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)

     sd15.set_num_inference_steps(n_steps)

@@ -1396,16 +1361,8 @@ def test_diffusion_sdxl_ip_adapter_plus(
         text=prompt, negative_text=negative_prompt
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)

     time_ids = sdxl.default_time_ids
     sdxl.set_num_inference_steps(n_steps)

Binary file not shown (before: 309 KiB, after: 309 KiB).
@@ -1,29 +0,0 @@
-import refiners.fluxion.layers as fl
-from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
-
-
-def test_cross_attention_adapter() -> None:
-    base = fl.Chain(fl.Attention(embedding_dim=4))
-    adapter = CrossAttentionAdapter(base.Attention).inject()
-
-    assert list(base) == [adapter]
-    assert len(list(adapter.layers(fl.Linear))) == 6
-    assert len(list(base.layers(fl.Linear))) == 6
-
-    injection_points = list(adapter.layers(InjectionPoint))
-    assert len(injection_points) == 4
-    for ip in injection_points:
-        assert len(ip) == 1
-        assert isinstance(ip[0], fl.Linear)
-
-    adapter.eject()
-
-    assert len(base) == 1
-    assert isinstance(base[0], fl.Attention)
-    assert len(list(adapter.layers(fl.Linear))) == 2
-    assert len(list(base.layers(fl.Linear))) == 4
-
-    injection_points = list(adapter.layers(InjectionPoint))
-    assert len(injection_points) == 4
-    for ip in injection_points:
-        assert len(ip) == 0