add IP-Adapter plus (aka fine-grained features)

Cédric Deltheil 2023-09-29 14:34:45 +02:00 committed by Cédric Deltheil
parent 88e454f1cb
commit 5fc6767a4a
9 changed files with 525 additions and 36 deletions
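For orientation, here is a minimal usage sketch of the fine-grained (IP-Adapter plus) mode this commit adds, condensed from the end-to-end tests further down. The import paths and weight file names are assumptions for illustration, not part of the diff; running it requires a checkpoint converted with the script below.

```python
import torch
from PIL import Image

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter  # assumed import path

# Target UNet, constructed the same way as in the conversion script below.
unet = SD1UNet(in_channels=4)

# fine_grained=True selects the new PerceiverResampler projection (IP-Adapter plus)
# instead of the original linear ImageProjection.
ip_adapter = SD1IPAdapter(
    target=unet,
    weights=load_from_safetensors("ip-adapter-plus_sd15.safetensors"),  # converted checkpoint, assumed file name
    fine_grained=True,
)
ip_adapter.clip_image_encoder.load_from_safetensors("CLIPImageEncoderH.safetensors")  # assumed file name
ip_adapter.inject()

# The resulting image embedding (negative + conditional halves) is concatenated to the
# CLIP text embedding before denoising, exactly as in the end-to-end tests below.
image = Image.open("statue.png")  # any reference image; file name assumed
image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image))
```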


@@ -70,12 +70,14 @@ def main() -> None:
     assert isinstance(weights, dict)
     assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
 
+    fine_grained = "latents" in weights["image_proj"]  # aka IP-Adapter plus
+
     match len(weights["ip_adapter"]):
         case 32:
-            ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
+            ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
             cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
         case 140:
-            ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
+            ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4), fine_grained=fine_grained)
            cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
         case _:
             raise ValueError("Unexpected number of keys in input checkpoint")
@@ -86,12 +88,43 @@ def main() -> None:
     state_dict: dict[str, torch.Tensor] = {}
 
     image_proj_weights = weights["image_proj"]
-    image_proj_state_dict: dict[str, torch.Tensor] = {
-        "Linear.weight": image_proj_weights["proj.weight"],
-        "Linear.bias": image_proj_weights["proj.bias"],
-        "LayerNorm.weight": image_proj_weights["norm.weight"],
-        "LayerNorm.bias": image_proj_weights["norm.bias"],
-    }
+    image_proj_state_dict: dict[str, torch.Tensor]
+
+    if fine_grained:
+        w = image_proj_weights
+        image_proj_state_dict = {
+            "LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0),  # drop batch dim = 1
+            "Linear_1.weight": w["proj_in.weight"],
+            "Linear_1.bias": w["proj_in.bias"],
+            "Linear_2.weight": w["proj_out.weight"],
+            "Linear_2.bias": w["proj_out.bias"],
+            "LayerNorm.weight": w["norm_out.weight"],
+            "LayerNorm.bias": w["norm_out.bias"],
+        }
+        for i in range(4):
+            t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}."
+            image_proj_state_dict.update(
+                {
+                    f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"],
+                    f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"],
+                    f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"],
+                    f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"],
+                    f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"],
+                    f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"],
+                    f"{t_pfx}1.Chain.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"],
+                    f"{t_pfx}2.Chain.LayerNorm.weight": w[f"{s_pfx}1.0.weight"],
+                    f"{t_pfx}2.Chain.LayerNorm.bias": w[f"{s_pfx}1.0.bias"],
+                    f"{t_pfx}2.Chain.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"],
+                    f"{t_pfx}2.Chain.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"],
+                }
+            )
+    else:
+        image_proj_state_dict = {
+            "Linear.weight": image_proj_weights["proj.weight"],
+            "Linear.bias": image_proj_weights["proj.bias"],
+            "LayerNorm.weight": image_proj_weights["norm.weight"],
+            "LayerNorm.bias": image_proj_weights["norm.bias"],
+        }
 
     ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
     for k, v in image_proj_state_dict.items():
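To make the mapping loop above concrete, here is the pair of key prefixes it generates for the first transformer layer (i = 0); this snippet is illustration only and not part of the conversion script.

```python
# Prefix on the refiners side (target keys) vs. on the tencent-ailab checkpoint side (source keys):
i = 0
t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}."

# e.g. the output projection of the first PerceiverAttention block:
print(f"{t_pfx}1.Chain.PerceiverAttention.Linear.weight")
# -> Transformer.TransformerLayer_1.Residual_1.Chain.PerceiverAttention.Linear.weight
print(f"{s_pfx}0.to_out.weight")
# -> layers.0.0.to_out.weight
```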


@@ -1,14 +1,16 @@
 from enum import IntEnum
 from functools import partial
 from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
+import math
 
-from torch import Tensor, cat, zeros_like, device as Device, dtype as DType
+from jaxtyping import Float
+from torch import Tensor, cat, softmax, zeros_like, device as Device, dtype as DType
 from PIL import Image
 
 from refiners.fluxion.adapters.adapter import Adapter
 from refiners.fluxion.adapters.lora import Lora
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
-from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
+from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor, normalize
 import refiners.fluxion.layers as fl
@@ -26,25 +28,219 @@ class ImageProjection(fl.Chain):
         self,
         clip_image_embedding_dim: int = 1024,
         clip_text_embedding_dim: int = 768,
-        sequence_length: int = 4,
+        num_tokens: int = 4,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
         self.clip_image_embedding_dim = clip_image_embedding_dim
         self.clip_text_embedding_dim = clip_text_embedding_dim
-        self.sequence_length = sequence_length
+        self.num_tokens = num_tokens
         super().__init__(
             fl.Linear(
                 in_features=clip_image_embedding_dim,
-                out_features=clip_text_embedding_dim * sequence_length,
+                out_features=clip_text_embedding_dim * num_tokens,
                 device=device,
                 dtype=dtype,
             ),
-            fl.Reshape(sequence_length, clip_text_embedding_dim),
+            fl.Reshape(num_tokens, clip_text_embedding_dim),
             fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype),
         )
+
+
+class FeedForward(fl.Chain):
+    def __init__(
+        self,
+        embedding_dim: int,
+        feedforward_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.embedding_dim = embedding_dim
+        self.feedforward_dim = feedforward_dim
+        super().__init__(
+            fl.Linear(
+                in_features=self.embedding_dim,
+                out_features=self.feedforward_dim,
+                bias=False,
+                device=device,
+                dtype=dtype,
+            ),
+            fl.GeLU(),
+            fl.Linear(
+                in_features=self.feedforward_dim,
+                out_features=self.embedding_dim,
+                bias=False,
+                device=device,
+                dtype=dtype,
+            ),
+        )
+
+
+# Adapted from https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py
+# See also:
+# - https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# - https://github.com/lucidrains/flamingo-pytorch
+class PerceiverScaledDotProductAttention(fl.Module):
+    def __init__(self, head_dim: int, num_heads: int) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        # See https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py#L69
+        # -> "More stable with f16 than dividing afterwards"
+        self.scale = 1 / math.sqrt(math.sqrt(head_dim))
+
+    def forward(
+        self,
+        key_value: Float[Tensor, "batch sequence_length 2*head_dim*num_heads"],
+        query: Float[Tensor, "batch num_tokens head_dim*num_heads"],
+    ) -> Float[Tensor, "batch num_tokens head_dim*num_heads"]:
+        bs, length, _ = query.shape
+        key, value = key_value.chunk(2, dim=-1)
+
+        q = self.reshape_tensor(query)
+        k = self.reshape_tensor(key)
+        v = self.reshape_tensor(value)
+
+        attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
+        attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
+        attention = attention @ v
+
+        return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
+
+    def reshape_tensor(
+        self, x: Float[Tensor, "batch length head_dim*num_heads"]
+    ) -> Float[Tensor, "batch num_heads length head_dim"]:
+        bs, length, _ = x.shape
+        x = x.view(bs, length, self.num_heads, -1)
+        x = x.transpose(1, 2)
+        x = x.reshape(bs, self.num_heads, length, -1)
+        return x
+
+
+class PerceiverAttention(fl.Chain):
+    def __init__(
+        self,
+        embedding_dim: int,
+        head_dim: int = 64,
+        num_heads: int = 8,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.embedding_dim = embedding_dim
+        self.head_dim = head_dim
+        self.inner_dim = head_dim * num_heads
+        super().__init__(
+            fl.Distribute(
+                fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
+                fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
+            ),
+            fl.Parallel(
+                fl.Chain(
+                    fl.Lambda(func=self.to_kv),
+                    fl.Linear(
+                        in_features=self.embedding_dim,
+                        out_features=2 * self.inner_dim,
+                        bias=False,
+                        device=device,
+                        dtype=dtype,
+                    ),  # Wkv
+                ),
+                fl.Chain(
+                    fl.GetArg(index=1),
+                    fl.Linear(
+                        in_features=self.embedding_dim,
+                        out_features=self.inner_dim,
+                        bias=False,
+                        device=device,
+                        dtype=dtype,
+                    ),  # Wq
+                ),
+            ),
+            PerceiverScaledDotProductAttention(head_dim=head_dim, num_heads=num_heads),
+            fl.Linear(
+                in_features=self.inner_dim, out_features=self.embedding_dim, bias=False, device=device, dtype=dtype
+            ),
+        )
+
+    def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
+        return cat((x, latents), dim=-2)
+
+
+class LatentsEncoder(fl.Chain):
+    def __init__(
+        self,
+        num_tokens: int,
+        embeddding_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            fl.Parallel(fl.Identity(), fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)),
+            fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)),
+        )
+
+
+class Transformer(fl.Chain):
+    pass
+
+
+class TransformerLayer(fl.Chain):
+    pass
+
+
+class PerceiverResampler(fl.Chain):
+    def __init__(
+        self,
+        latents_dim: int = 1024,
+        num_attention_layers: int = 8,
+        num_attention_heads: int = 16,
+        head_dim: int = 64,
+        num_tokens: int = 8,
+        input_dim: int = 768,
+        output_dim: int = 1024,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.latents_dim = latents_dim
+        self.num_attention_layers = num_attention_layers
+        self.head_dim = head_dim
+        self.num_attention_heads = num_attention_heads
+        self.num_tokens = num_tokens
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.feedforward_dim = 4 * self.latents_dim
+        super().__init__(
+            fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
+            fl.SetContext(context="perceiver_resampler", key="x"),
+            LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype),
+            Transformer(
+                TransformerLayer(
+                    fl.Residual(
+                        fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
+                        PerceiverAttention(
+                            embedding_dim=latents_dim,
+                            head_dim=head_dim,
+                            num_heads=num_attention_heads,
+                            device=device,
+                            dtype=dtype,
+                        ),
+                    ),
+                    fl.Residual(
+                        fl.LayerNorm(normalized_shape=latents_dim, device=device, dtype=dtype),
+                        FeedForward(
+                            embedding_dim=latents_dim, feedforward_dim=self.feedforward_dim, device=device, dtype=dtype
+                        ),
+                    ),
+                )
+                for _ in range(num_attention_layers)
+            ),
+            fl.Linear(in_features=latents_dim, out_features=output_dim, device=device, dtype=dtype),
+            fl.LayerNorm(normalized_shape=output_dim, device=device, dtype=dtype),
+        )
+
+    def init_context(self) -> Contexts:
+        return {"perceiver_resampler": {"x": None}}
 
 
 class _CrossAttnIndex(IntEnum):
     TXT_CROSS_ATTN = 0  # text cross-attention
     IMG_CROSS_ATTN = 1  # image cross-attention
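As a sanity check on the shapes involved, the sketch below instantiates a PerceiverResampler with the SD 1.5 fine-grained configuration used later in this commit (4 layers, 12 heads, 16 latent tokens) and feeds it a dummy grid-feature tensor. The (1, 257, 1280) input shape assumes ViT-H/14 grid features at 224 px; it is an assumption of this sketch, not something asserted by the diff.

```python
import torch
from refiners.foundationals.latent_diffusion.image_prompt import PerceiverResampler

# SD 1.5 "plus" configuration (see the SD1IPAdapter constructor further down).
resampler = PerceiverResampler(
    latents_dim=768,
    num_attention_layers=4,
    num_attention_heads=12,
    head_dim=64,
    num_tokens=16,
    input_dim=1280,  # ViT-H hidden size, i.e. the dim before the final projection
    output_dim=768,  # SD 1.5 cross-attention context dim
)

grid_features = torch.randn(1, 257, 1280)  # 256 patches + class token (assumed)
tokens = resampler(grid_features)
print(tokens.shape)  # expected: torch.Size([1, 16, 768])
```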
@@ -164,32 +360,29 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
 
 class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
     # Prevent PyTorch module registration
     _clip_image_encoder: list[CLIPImageEncoderH]
-    _image_proj: list[ImageProjection]
+    _grid_image_encoder: list[CLIPImageEncoderH]
+    _image_proj: list[fl.Module]
 
     def __init__(
         self,
         target: T,
-        clip_image_encoder: CLIPImageEncoderH | None = None,
+        clip_image_encoder: CLIPImageEncoderH,
+        image_proj: fl.Module,
         scale: float = 1.0,
+        fine_grained: bool = False,
         weights: dict[str, Tensor] | None = None,
     ) -> None:
         with self.setup_adapter(target):
             super().__init__(target)
 
-        cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
-
-        self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)]
-        self._image_proj = [
-            ImageProjection(
-                clip_image_embedding_dim=self.clip_image_encoder.output_dim,
-                clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
-                device=target.device,
-                dtype=target.dtype,
-            )
-        ]
+        self.fine_grained = fine_grained
+        self._clip_image_encoder = [clip_image_encoder]
+        if fine_grained:
+            self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]
+        self._image_proj = [image_proj]
 
         self.sub_adapters = [
-            CrossAttentionAdapter(target=cross_attn, scale=scale)
+            CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
             for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
         ]
@@ -214,7 +407,12 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         return self._clip_image_encoder[0]
 
     @property
-    def image_proj(self) -> ImageProjection:
+    def grid_image_encoder(self) -> CLIPImageEncoderH:
+        assert hasattr(self, "_grid_image_encoder")
+        return self._grid_image_encoder[0]
+
+    @property
+    def image_proj(self) -> fl.Module:
         return self._image_proj[0]
 
     def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
@@ -228,10 +426,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         super().eject()
 
     # These should be concatenated to the CLIP text embedding before setting the UNet context
-    def compute_clip_image_embedding(self, image_prompt: Tensor | None) -> Tensor:
-        clip_embedding = self.clip_image_encoder(image_prompt)
+    def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
+        image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
+        clip_embedding = image_encoder(image_prompt)
         conditional_embedding = self.image_proj(clip_embedding)
-        negative_embedding = self.image_proj(zeros_like(clip_embedding))
+        if not self.fine_grained:
+            negative_embedding = self.image_proj(zeros_like(clip_embedding))
+        else:
+            # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
+            clip_embedding = image_encoder(zeros_like(image_prompt))
+            negative_embedding = self.image_proj(clip_embedding)
         return cat((negative_embedding, conditional_embedding))
 
     def preprocess_image(
def preprocess_image( def preprocess_image(
@ -247,3 +451,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean, mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std, std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
) )
@staticmethod
def convert_to_grid_features(clip_image_encoder: CLIPImageEncoderH) -> CLIPImageEncoderH:
encoder_clone = clip_image_encoder.structural_copy()
assert isinstance(encoder_clone[-1], fl.Linear) # final proj
assert isinstance(encoder_clone[-2], fl.LayerNorm) # final normalization
assert isinstance(encoder_clone[-3], fl.Lambda) # pooling (classif token)
for _ in range(3):
encoder_clone.pop()
transfomer_layers = encoder_clone[-1]
assert isinstance(transfomer_layers, fl.Chain) and len(transfomer_layers) == 32
transfomer_layers.pop()
return encoder_clone
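To illustrate what convert_to_grid_features yields: a rough shape check, assuming a 224×224 input and ViT-H/14 (so 257 tokens of width 1280 taken from the penultimate transformer layer). Randomly initialized weights are fine here; only the shapes matter, and the shapes in the comments are expectations, not values asserted by the diff.

```python
import torch
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter

encoder = CLIPImageEncoderH()  # randomly initialized, good enough for a shape check
grid_encoder = IPAdapter.convert_to_grid_features(encoder)

image = torch.randn(1, 3, 224, 224)
print(encoder(image).shape)       # pooled, projected embedding: torch.Size([1, 1024]) expected
print(grid_encoder(image).shape)  # per-token grid features: torch.Size([1, 257, 1280]) expected
```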


@@ -1,6 +1,53 @@
-from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
+from torch import Tensor
+
+from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
+from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler
 from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
 
 
 class SD1IPAdapter(IPAdapter[SD1UNet]):
-    pass
+    def __init__(
+        self,
+        target: SD1UNet,
+        clip_image_encoder: CLIPImageEncoderH | None = None,
+        image_proj: ImageProjection | PerceiverResampler | None = None,
+        scale: float = 1.0,
+        fine_grained: bool = False,
+        weights: dict[str, Tensor] | None = None,
+    ) -> None:
+        clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
+
+        if image_proj is None:
+            cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
+            image_proj = (
+                ImageProjection(
+                    clip_image_embedding_dim=clip_image_encoder.output_dim,
+                    clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
+                    device=target.device,
+                    dtype=target.dtype,
+                )
+                if not fine_grained
+                else PerceiverResampler(
+                    latents_dim=cross_attn_2d.context_embedding_dim,
+                    num_attention_layers=4,
+                    num_attention_heads=12,
+                    head_dim=64,
+                    num_tokens=16,
+                    input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection
+                    output_dim=cross_attn_2d.context_embedding_dim,
+                    device=target.device,
+                    dtype=target.dtype,
+                )
+            )
+        elif fine_grained:
+            assert isinstance(image_proj, PerceiverResampler)
+
+        super().__init__(
+            target=target,
+            clip_image_encoder=clip_image_encoder,
+            image_proj=image_proj,
+            scale=scale,
+            fine_grained=fine_grained,
+            weights=weights,
+        )


@@ -1,6 +1,53 @@
-from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
+from torch import Tensor
+
+from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
+from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
 
 
 class SDXLIPAdapter(IPAdapter[SDXLUNet]):
-    pass
+    def __init__(
+        self,
+        target: SDXLUNet,
+        clip_image_encoder: CLIPImageEncoderH | None = None,
+        image_proj: ImageProjection | PerceiverResampler | None = None,
+        scale: float = 1.0,
+        fine_grained: bool = False,
+        weights: dict[str, Tensor] | None = None,
+    ) -> None:
+        clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
+
+        if image_proj is None:
+            cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
+            image_proj = (
+                ImageProjection(
+                    clip_image_embedding_dim=clip_image_encoder.output_dim,
+                    clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
+                    device=target.device,
+                    dtype=target.dtype,
+                )
+                if not fine_grained
+                else PerceiverResampler(
+                    latents_dim=1280,  # not `cross_attn_2d.context_embedding_dim` in this case
+                    num_attention_layers=4,
+                    num_attention_heads=20,
+                    head_dim=64,
+                    num_tokens=16,
+                    input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection
+                    output_dim=cross_attn_2d.context_embedding_dim,
+                    device=target.device,
+                    dtype=target.dtype,
+                )
+            )
+        elif fine_grained:
+            assert isinstance(image_proj, PerceiverResampler)
+
+        super().__init__(
+            target=target,
+            clip_image_encoder=clip_image_encoder,
+            image_proj=image_proj,
+            scale=scale,
+            fine_grained=fine_grained,
+            weights=weights,
+        )


@@ -54,6 +54,11 @@ def woman_image(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "woman.png").convert("RGB")
 
 
+@pytest.fixture(scope="module")
+def statue_image(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "statue.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_std_random_init(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
@@ -79,11 +84,21 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
@@ -255,6 +270,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
     return ip_adapter_weights
 
 
+@pytest.fixture(scope="module")
+def ip_adapter_plus_weights(test_weights_path: Path) -> Path:
+    ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors"
+    if not ip_adapter_weights.is_file():
+        warn(f"could not find weights at {ip_adapter_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return ip_adapter_weights
+
+
 @pytest.fixture(scope="module")
 def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
     ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
@@ -264,6 +288,15 @@ def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
     return ip_adapter_weights
 
 
+@pytest.fixture(scope="module")
+def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
+    ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
+    if not ip_adapter_weights.is_file():
+        warn(f"could not find weights at {ip_adapter_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return ip_adapter_weights
+
+
 @pytest.fixture(scope="module")
 def image_encoder_weights(test_weights_path: Path) -> Path:
     image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
@@ -1189,6 +1222,115 @@ def test_diffusion_ip_adapter_controlnet(
     ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
 
 
+@torch.no_grad()
+def test_diffusion_ip_adapter_plus(
+    sd15_ddim_lda_ft_mse: StableDiffusion_1,
+    ip_adapter_plus_weights: Path,
+    image_encoder_weights: Path,
+    statue_image: Image.Image,
+    expected_image_ip_adapter_plus_statue: Image.Image,
+    test_device: torch.device,
+):
+    sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
+    n_steps = 50
+
+    prompt = "best quality, high quality"
+    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+    ip_adapter = SD1IPAdapter(
+        target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True
+    )
+    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.inject()
+
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
+
+    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
+    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
+
+    clip_text_embedding = torch.cat(
+        (
+            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
+            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
+        )
+    )
+
+    sd15.set_num_inference_steps(n_steps)
+
+    manual_seed(42)  # seed=42 is used in the official IP-Adapter demo
+    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image.save("output.png")
+
+    ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
+
+
+@torch.no_grad()
+def test_diffusion_sdxl_ip_adapter_plus(
+    sdxl_ddim: StableDiffusion_XL,
+    sdxl_ip_adapter_plus_weights: Path,
+    image_encoder_weights: Path,
+    woman_image: Image.Image,
+    expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
+    test_device: torch.device,
+):
+    sdxl = sdxl_ddim.to(dtype=torch.float16)
+    n_steps = 30
+
+    prompt = "best quality, high quality"
+    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+    ip_adapter = SDXLIPAdapter(
+        target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True
+    )
+    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.inject()
+
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text=prompt, negative_text=negative_prompt
+    )
+    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
+
+    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
+    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
+
+    clip_text_embedding = torch.cat(
+        (
+            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
+            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
+        )
+    )
+    time_ids = sdxl.default_time_ids
+
+    sdxl.set_num_inference_steps(n_steps)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
+
+    for step in sdxl.steps:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=5,
+        )
+    sdxl.lda.to(dtype=torch.float32)
+    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+    predicted_image.save("output.png")
+
+    ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
+
+
 @torch.no_grad()
 def test_sdxl_random_init(
     sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device


@@ -41,6 +41,7 @@ Special cases:
 - `expected_image_sdxl_ip_adapter_woman.png`
 - `expected_ip_adapter_controlnet.png`
 - `expected_t2i_adapter_xl_canny.png`
+- `expected_image_sdxl_ip_adapter_plus_woman.png`
 
 ## Other images
 
@@ -68,6 +69,8 @@ Special cases:
 - `woman.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/8b96670cc5c8ef00278b42c0c7b62fe8a74510b9/assets/images/woman.png).
 
+- `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png).
+
 ## VAE without randomness
 
 ```diff

Binary files not shown: three new reference images added (269 KiB, 1.1 MiB, 330 KiB).