mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add IP-Adapter plus (aka fine-grained features)
This commit is contained in:
parent
88e454f1cb
commit
5fc6767a4a
|
@ -70,12 +70,14 @@ def main() -> None:
|
|||
assert isinstance(weights, dict)
|
||||
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
||||
|
||||
fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus
|
||||
|
||||
match len(weights["ip_adapter"]):
|
||||
case 32:
|
||||
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
|
||||
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
|
||||
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
|
||||
case 140:
|
||||
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
|
||||
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4), fine_grained=fine_grained)
|
||||
cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
|
||||
case _:
|
||||
raise ValueError("Unexpected number of keys in input checkpoint")
|
||||
|
@ -86,12 +88,43 @@ def main() -> None:
|
|||
state_dict: dict[str, torch.Tensor] = {}
|
||||
|
||||
image_proj_weights = weights["image_proj"]
|
||||
image_proj_state_dict: dict[str, torch.Tensor] = {
|
||||
"Linear.weight": image_proj_weights["proj.weight"],
|
||||
"Linear.bias": image_proj_weights["proj.bias"],
|
||||
"LayerNorm.weight": image_proj_weights["norm.weight"],
|
||||
"LayerNorm.bias": image_proj_weights["norm.bias"],
|
||||
}
|
||||
image_proj_state_dict: dict[str, torch.Tensor]
|
||||
|
||||
if fine_grained:
|
||||
w = image_proj_weights
|
||||
image_proj_state_dict = {
|
||||
"LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0), # drop batch dim = 1
|
||||
"Linear_1.weight": w["proj_in.weight"],
|
||||
"Linear_1.bias": w["proj_in.bias"],
|
||||
"Linear_2.weight": w["proj_out.weight"],
|
||||
"Linear_2.bias": w["proj_out.bias"],
|
||||
"LayerNorm.weight": w["norm_out.weight"],
|
||||
"LayerNorm.bias": w["norm_out.bias"],
|
||||
}
|
||||
for i in range(4):
|
||||
t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}."
|
||||
image_proj_state_dict.update(
|
||||
{
|
||||
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"],
|
||||
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"],
|
||||
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"],
|
||||
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"],
|
||||
f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"],
|
||||
f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"],
|
||||
f"{t_pfx}1.Chain.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"],
|
||||
f"{t_pfx}2.Chain.LayerNorm.weight": w[f"{s_pfx}1.0.weight"],
|
||||
f"{t_pfx}2.Chain.LayerNorm.bias": w[f"{s_pfx}1.0.bias"],
|
||||
f"{t_pfx}2.Chain.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"],
|
||||
f"{t_pfx}2.Chain.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
image_proj_state_dict = {
|
||||
"Linear.weight": image_proj_weights["proj.weight"],
|
||||
"Linear.bias": image_proj_weights["proj.bias"],
|
||||
"LayerNorm.weight": image_proj_weights["norm.weight"],
|
||||
"LayerNorm.bias": image_proj_weights["norm.bias"],
|
||||
}
|
||||
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
|
||||
|
||||
for k, v in image_proj_state_dict.items():
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
from enum import IntEnum
|
||||
from functools import partial
|
||||
from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
|
||||
import math
|
||||
|
||||
from torch import Tensor, cat, zeros_like, device as Device, dtype as DType
|
||||
from jaxtyping import Float
|
||||
from torch import Tensor, cat, softmax, zeros_like, device as Device, dtype as DType
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import Lora
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||
from refiners.fluxion.utils import image_to_tensor, normalize
|
||||
import refiners.fluxion.layers as fl
|
||||
|
@ -26,25 +28,219 @@ class ImageProjection(fl.Chain):
|
|||
self,
|
||||
clip_image_embedding_dim: int = 1024,
|
||||
clip_text_embedding_dim: int = 768,
|
||||
sequence_length: int = 4,
|
||||
num_tokens: int = 4,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.clip_image_embedding_dim = clip_image_embedding_dim
|
||||
self.clip_text_embedding_dim = clip_text_embedding_dim
|
||||
self.sequence_length = sequence_length
|
||||
self.num_tokens = num_tokens
|
||||
super().__init__(
|
||||
fl.Linear(
|
||||
in_features=clip_image_embedding_dim,
|
||||
out_features=clip_text_embedding_dim * sequence_length,
|
||||
out_features=clip_text_embedding_dim * num_tokens,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Reshape(sequence_length, clip_text_embedding_dim),
|
||||
fl.Reshape(num_tokens, clip_text_embedding_dim),
|
||||
fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
|
||||
class FeedForward(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
feedforward_dim: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.feedforward_dim = feedforward_dim
|
||||
super().__init__(
|
||||
fl.Linear(
|
||||
in_features=self.embedding_dim,
|
||||
out_features=self.feedforward_dim,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.GeLU(),
|
||||
fl.Linear(
|
||||
in_features=self.feedforward_dim,
|
||||
out_features=self.embedding_dim,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Adapted from https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py
|
||||
# See also:
|
||||
# - https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
# - https://github.com/lucidrains/flamingo-pytorch
|
||||
class PerceiverScaledDotProductAttention(fl.Module):
|
||||
def __init__(self, head_dim: int, num_heads: int) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
# See https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py#L69
|
||||
# -> "More stable with f16 than dividing afterwards"
|
||||
self.scale = 1 / math.sqrt(math.sqrt(head_dim))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
key_value: Float[Tensor, "batch sequence_length 2*head_dim*num_heads"],
|
||||
query: Float[Tensor, "batch num_tokens head_dim*num_heads"],
|
||||
) -> Float[Tensor, "batch num_tokens head_dim*num_heads"]:
|
||||
bs, length, _ = query.shape
|
||||
key, value = key_value.chunk(2, dim=-1)
|
||||
|
||||
q = self.reshape_tensor(query)
|
||||
k = self.reshape_tensor(key)
|
||||
v = self.reshape_tensor(value)
|
||||
|
||||
attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
|
||||
attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
|
||||
attention = attention @ v
|
||||
|
||||
return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
|
||||
|
||||
def reshape_tensor(
|
||||
self, x: Float[Tensor, "batch length head_dim*num_heads"]
|
||||
) -> Float[Tensor, "batch num_heads length head_dim"]:
|
||||
bs, length, _ = x.shape
|
||||
x = x.view(bs, length, self.num_heads, -1)
|
||||
x = x.transpose(1, 2)
|
||||
x = x.reshape(bs, self.num_heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
head_dim: int = 64,
|
||||
num_heads: int = 8,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.head_dim = head_dim
|
||||
self.inner_dim = head_dim * num_heads
|
||||
super().__init__(
|
||||
fl.Distribute(
|
||||
fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
|
||||
fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
|
||||
),
|
||||
fl.Parallel(
|
||||
fl.Chain(
|
||||
fl.Lambda(func=self.to_kv),
|
||||
fl.Linear(
|
||||
in_features=self.embedding_dim,
|
||||
out_features=2 * self.inner_dim,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
), # Wkv
|
||||
),
|
||||
fl.Chain(
|
||||
fl.GetArg(index=1),
|
||||
fl.Linear(
|
||||
in_features=self.embedding_dim,
|
||||
out_features=self.inner_dim,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
), # Wq
|
||||
),
|
||||
),
|
||||
PerceiverScaledDotProductAttention(head_dim=head_dim, num_heads=num_heads),
|
||||
fl.Linear(
|
||||
in_features=self.inner_dim, out_features=self.embedding_dim, bias=False, device=device, dtype=dtype
|
||||
),
|
||||
)
|
||||
|
||||
def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
|
||||
return cat((x, latents), dim=-2)
|
||||
|
||||
|
||||
class LatentsEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
num_tokens: int,
|
||||
embeddding_dim: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Parallel(fl.Identity(), fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)),
|
||||
fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)),
|
||||
)
|
||||
|
||||
|
||||
class Transformer(fl.Chain):
|
||||
pass
|
||||
|
||||
|
||||
class TransformerLayer(fl.Chain):
|
||||
pass
|
||||
|
||||
|
||||
class PerceiverResampler(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
latents_dim: int = 1024,
|
||||
num_attention_layers: int = 8,
|
||||
num_attention_heads: int = 16,
|
||||
head_dim: int = 64,
|
||||
num_tokens: int = 8,
|
||||
input_dim: int = 768,
|
||||
output_dim: int = 1024,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.latents_dim = latents_dim
|
||||
self.num_attention_layers = num_attention_layers
|
||||
self.head_dim = head_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_tokens = num_tokens
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.feedforward_dim = 4 * self.latents_dim
|
||||
super().__init__(
|
||||
fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
|
||||
fl.SetContext(context="perceiver_resampler", key="x"),
|
||||
LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype),
|
||||
Transformer(
|
||||
TransformerLayer(
|
||||
fl.Residual(
|
||||
fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
|
||||
PerceiverAttention(
|
||||
embedding_dim=latents_dim,
|
||||
head_dim=head_dim,
|
||||
num_heads=num_attention_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
),
|
||||
fl.Residual(
|
||||
fl.LayerNorm(normalized_shape=latents_dim, device=device, dtype=dtype),
|
||||
FeedForward(
|
||||
embedding_dim=latents_dim, feedforward_dim=self.feedforward_dim, device=device, dtype=dtype
|
||||
),
|
||||
),
|
||||
)
|
||||
for _ in range(num_attention_layers)
|
||||
),
|
||||
fl.Linear(in_features=latents_dim, out_features=output_dim, device=device, dtype=dtype),
|
||||
fl.LayerNorm(normalized_shape=output_dim, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"perceiver_resampler": {"x": None}}
|
||||
|
||||
|
||||
class _CrossAttnIndex(IntEnum):
|
||||
TXT_CROSS_ATTN = 0 # text cross-attention
|
||||
IMG_CROSS_ATTN = 1 # image cross-attention
|
||||
|
@ -164,32 +360,29 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
|||
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
# Prevent PyTorch module registration
|
||||
_clip_image_encoder: list[CLIPImageEncoderH]
|
||||
_image_proj: list[ImageProjection]
|
||||
_grid_image_encoder: list[CLIPImageEncoderH]
|
||||
_image_proj: list[fl.Module]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: T,
|
||||
clip_image_encoder: CLIPImageEncoderH | None = None,
|
||||
clip_image_encoder: CLIPImageEncoderH,
|
||||
image_proj: fl.Module,
|
||||
scale: float = 1.0,
|
||||
fine_grained: bool = False,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
|
||||
|
||||
self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)]
|
||||
self._image_proj = [
|
||||
ImageProjection(
|
||||
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
|
||||
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
)
|
||||
]
|
||||
self.fine_grained = fine_grained
|
||||
self._clip_image_encoder = [clip_image_encoder]
|
||||
if fine_grained:
|
||||
self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]
|
||||
self._image_proj = [image_proj]
|
||||
|
||||
self.sub_adapters = [
|
||||
CrossAttentionAdapter(target=cross_attn, scale=scale)
|
||||
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
|
||||
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
|
||||
]
|
||||
|
||||
|
@ -214,7 +407,12 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
return self._clip_image_encoder[0]
|
||||
|
||||
@property
|
||||
def image_proj(self) -> ImageProjection:
|
||||
def grid_image_encoder(self) -> CLIPImageEncoderH:
|
||||
assert hasattr(self, "_grid_image_encoder")
|
||||
return self._grid_image_encoder[0]
|
||||
|
||||
@property
|
||||
def image_proj(self) -> fl.Module:
|
||||
return self._image_proj[0]
|
||||
|
||||
def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
|
||||
|
@ -228,10 +426,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
super().eject()
|
||||
|
||||
# These should be concatenated to the CLIP text embedding before setting the UNet context
|
||||
def compute_clip_image_embedding(self, image_prompt: Tensor | None) -> Tensor:
|
||||
clip_embedding = self.clip_image_encoder(image_prompt)
|
||||
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
|
||||
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
||||
clip_embedding = image_encoder(image_prompt)
|
||||
conditional_embedding = self.image_proj(clip_embedding)
|
||||
negative_embedding = self.image_proj(zeros_like(clip_embedding))
|
||||
if not self.fine_grained:
|
||||
negative_embedding = self.image_proj(zeros_like(clip_embedding))
|
||||
else:
|
||||
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
|
||||
clip_embedding = image_encoder(zeros_like(image_prompt))
|
||||
negative_embedding = self.image_proj(clip_embedding)
|
||||
return cat((negative_embedding, conditional_embedding))
|
||||
|
||||
def preprocess_image(
|
||||
|
@ -247,3 +451,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
||||
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_to_grid_features(clip_image_encoder: CLIPImageEncoderH) -> CLIPImageEncoderH:
|
||||
encoder_clone = clip_image_encoder.structural_copy()
|
||||
assert isinstance(encoder_clone[-1], fl.Linear) # final proj
|
||||
assert isinstance(encoder_clone[-2], fl.LayerNorm) # final normalization
|
||||
assert isinstance(encoder_clone[-3], fl.Lambda) # pooling (classif token)
|
||||
for _ in range(3):
|
||||
encoder_clone.pop()
|
||||
transfomer_layers = encoder_clone[-1]
|
||||
assert isinstance(transfomer_layers, fl.Chain) and len(transfomer_layers) == 32
|
||||
transfomer_layers.pop()
|
||||
return encoder_clone
|
||||
|
|
|
@ -1,6 +1,53 @@
|
|||
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
|
||||
from torch import Tensor
|
||||
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
|
||||
|
||||
|
||||
class SD1IPAdapter(IPAdapter[SD1UNet]):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
target: SD1UNet,
|
||||
clip_image_encoder: CLIPImageEncoderH | None = None,
|
||||
image_proj: ImageProjection | PerceiverResampler | None = None,
|
||||
scale: float = 1.0,
|
||||
fine_grained: bool = False,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
||||
|
||||
if image_proj is None:
|
||||
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
|
||||
image_proj = (
|
||||
ImageProjection(
|
||||
clip_image_embedding_dim=clip_image_encoder.output_dim,
|
||||
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
)
|
||||
if not fine_grained
|
||||
else PerceiverResampler(
|
||||
latents_dim=cross_attn_2d.context_embedding_dim,
|
||||
num_attention_layers=4,
|
||||
num_attention_heads=12,
|
||||
head_dim=64,
|
||||
num_tokens=16,
|
||||
input_dim=clip_image_encoder.embedding_dim, # = dim before final projection
|
||||
output_dim=cross_attn_2d.context_embedding_dim,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
)
|
||||
)
|
||||
elif fine_grained:
|
||||
assert isinstance(image_proj, PerceiverResampler)
|
||||
|
||||
super().__init__(
|
||||
target=target,
|
||||
clip_image_encoder=clip_image_encoder,
|
||||
image_proj=image_proj,
|
||||
scale=scale,
|
||||
fine_grained=fine_grained,
|
||||
weights=weights,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,53 @@
|
|||
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
|
||||
from torch import Tensor
|
||||
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||
|
||||
|
||||
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
target: SDXLUNet,
|
||||
clip_image_encoder: CLIPImageEncoderH | None = None,
|
||||
image_proj: ImageProjection | PerceiverResampler | None = None,
|
||||
scale: float = 1.0,
|
||||
fine_grained: bool = False,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
||||
|
||||
if image_proj is None:
|
||||
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
|
||||
image_proj = (
|
||||
ImageProjection(
|
||||
clip_image_embedding_dim=clip_image_encoder.output_dim,
|
||||
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
)
|
||||
if not fine_grained
|
||||
else PerceiverResampler(
|
||||
latents_dim=1280, # not `cross_attn_2d.context_embedding_dim` in this case
|
||||
num_attention_layers=4,
|
||||
num_attention_heads=20,
|
||||
head_dim=64,
|
||||
num_tokens=16,
|
||||
input_dim=clip_image_encoder.embedding_dim, # = dim before final projection
|
||||
output_dim=cross_attn_2d.context_embedding_dim,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
)
|
||||
)
|
||||
elif fine_grained:
|
||||
assert isinstance(image_proj, PerceiverResampler)
|
||||
|
||||
super().__init__(
|
||||
target=target,
|
||||
clip_image_encoder=clip_image_encoder,
|
||||
image_proj=image_proj,
|
||||
scale=scale,
|
||||
fine_grained=fine_grained,
|
||||
weights=weights,
|
||||
)
|
||||
|
|
|
@ -54,6 +54,11 @@ def woman_image(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def statue_image(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "statue.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||
|
@ -79,11 +84,21 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
|
||||
|
@ -255,6 +270,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
|
|||
return ip_adapter_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ip_adapter_plus_weights(test_weights_path: Path) -> Path:
|
||||
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors"
|
||||
if not ip_adapter_weights.is_file():
|
||||
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return ip_adapter_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
|
||||
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
|
||||
|
@ -264,6 +288,15 @@ def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
|
|||
return ip_adapter_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
|
||||
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
|
||||
if not ip_adapter_weights.is_file():
|
||||
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return ip_adapter_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def image_encoder_weights(test_weights_path: Path) -> Path:
|
||||
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
||||
|
@ -1189,6 +1222,115 @@ def test_diffusion_ip_adapter_controlnet(
|
|||
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_ip_adapter_plus(
|
||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||
ip_adapter_plus_weights: Path,
|
||||
image_encoder_weights: Path,
|
||||
statue_image: Image.Image,
|
||||
expected_image_ip_adapter_plus_statue: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
|
||||
n_steps = 50
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
ip_adapter = SD1IPAdapter(
|
||||
target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True
|
||||
)
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
manual_seed(42) # seed=42 is used in the official IP-Adapter demo
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image.save("output.png")
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_sdxl_ip_adapter_plus(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
sdxl_ip_adapter_plus_weights: Path,
|
||||
image_encoder_weights: Path,
|
||||
woman_image: Image.Image,
|
||||
expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sdxl = sdxl_ddim.to(dtype=torch.float16)
|
||||
n_steps = 30
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
ip_adapter = SDXLIPAdapter(
|
||||
target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True
|
||||
)
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt, negative_text=negative_prompt
|
||||
)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
|
||||
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
condition_scale=5,
|
||||
)
|
||||
sdxl.lda.to(dtype=torch.float32)
|
||||
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
|
||||
predicted_image.save("output.png")
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_sdxl_random_init(
|
||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
||||
|
|
|
@ -41,6 +41,7 @@ Special cases:
|
|||
- `expected_image_sdxl_ip_adapter_woman.png`
|
||||
- `expected_ip_adapter_controlnet.png`
|
||||
- `expected_t2i_adapter_xl_canny.png`
|
||||
- `expected_image_sdxl_ip_adapter_plus_woman.png`
|
||||
|
||||
## Other images
|
||||
|
||||
|
@ -68,6 +69,8 @@ Special cases:
|
|||
|
||||
- `woman.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/8b96670cc5c8ef00278b42c0c7b62fe8a74510b9/assets/images/woman.png).
|
||||
|
||||
- `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png).
|
||||
|
||||
## VAE without randomness
|
||||
|
||||
```diff
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 269 KiB |
Binary file not shown.
After Width: | Height: | Size: 1.1 MiB |
BIN
tests/e2e/test_diffusion_ref/statue.png
Normal file
BIN
tests/e2e/test_diffusion_ref/statue.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 330 KiB |
Loading…
Reference in a new issue