add IP-Adapter plus (aka fine-grained features)

This commit is contained in:
Cédric Deltheil 2023-09-29 14:34:45 +02:00 committed by Cédric Deltheil
parent 88e454f1cb
commit 5fc6767a4a
9 changed files with 525 additions and 36 deletions

View file

@ -70,12 +70,14 @@ def main() -> None:
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus
match len(weights["ip_adapter"]):
case 32:
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
case 140:
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4), fine_grained=fine_grained)
cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
case _:
raise ValueError("Unexpected number of keys in input checkpoint")
@ -86,12 +88,43 @@ def main() -> None:
state_dict: dict[str, torch.Tensor] = {}
image_proj_weights = weights["image_proj"]
image_proj_state_dict: dict[str, torch.Tensor] = {
"Linear.weight": image_proj_weights["proj.weight"],
"Linear.bias": image_proj_weights["proj.bias"],
"LayerNorm.weight": image_proj_weights["norm.weight"],
"LayerNorm.bias": image_proj_weights["norm.bias"],
}
image_proj_state_dict: dict[str, torch.Tensor]
if fine_grained:
w = image_proj_weights
image_proj_state_dict = {
"LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0), # drop batch dim = 1
"Linear_1.weight": w["proj_in.weight"],
"Linear_1.bias": w["proj_in.bias"],
"Linear_2.weight": w["proj_out.weight"],
"Linear_2.bias": w["proj_out.bias"],
"LayerNorm.weight": w["norm_out.weight"],
"LayerNorm.bias": w["norm_out.bias"],
}
for i in range(4):
t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}."
image_proj_state_dict.update(
{
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"],
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"],
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"],
f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"],
f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"],
f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"],
f"{t_pfx}1.Chain.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"],
f"{t_pfx}2.Chain.LayerNorm.weight": w[f"{s_pfx}1.0.weight"],
f"{t_pfx}2.Chain.LayerNorm.bias": w[f"{s_pfx}1.0.bias"],
f"{t_pfx}2.Chain.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"],
f"{t_pfx}2.Chain.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"],
}
)
else:
image_proj_state_dict = {
"Linear.weight": image_proj_weights["proj.weight"],
"Linear.bias": image_proj_weights["proj.bias"],
"LayerNorm.weight": image_proj_weights["norm.weight"],
"LayerNorm.bias": image_proj_weights["norm.bias"],
}
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
for k, v in image_proj_state_dict.items():

View file

@ -1,14 +1,16 @@
from enum import IntEnum
from functools import partial
from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
import math
from torch import Tensor, cat, zeros_like, device as Device, dtype as DType
from jaxtyping import Float
from torch import Tensor, cat, softmax, zeros_like, device as Device, dtype as DType
from PIL import Image
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Lora
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, normalize
import refiners.fluxion.layers as fl
@ -26,25 +28,219 @@ class ImageProjection(fl.Chain):
self,
clip_image_embedding_dim: int = 1024,
clip_text_embedding_dim: int = 768,
sequence_length: int = 4,
num_tokens: int = 4,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.clip_image_embedding_dim = clip_image_embedding_dim
self.clip_text_embedding_dim = clip_text_embedding_dim
self.sequence_length = sequence_length
self.num_tokens = num_tokens
super().__init__(
fl.Linear(
in_features=clip_image_embedding_dim,
out_features=clip_text_embedding_dim * sequence_length,
out_features=clip_text_embedding_dim * num_tokens,
device=device,
dtype=dtype,
),
fl.Reshape(sequence_length, clip_text_embedding_dim),
fl.Reshape(num_tokens, clip_text_embedding_dim),
fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype),
)
class FeedForward(fl.Chain):
    """Bias-free two-layer MLP: Linear -> GeLU -> Linear.

    The first projection maps `embedding_dim` to `feedforward_dim`, the second
    one maps back to `embedding_dim`.
    """

    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        # Both projections only differ by their in/out sizes.
        make_linear = partial(fl.Linear, bias=False, device=device, dtype=dtype)
        super().__init__(
            make_linear(in_features=embedding_dim, out_features=feedforward_dim),
            fl.GeLU(),
            make_linear(in_features=feedforward_dim, out_features=embedding_dim),
        )
# Adapted from https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py
# See also:
# - https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# - https://github.com/lucidrains/flamingo-pytorch
class PerceiverScaledDotProductAttention(fl.Module):
    """Multi-head attention of queries over a packed key/value tensor.

    The last dimension of `key_value` is split in two halves (keys first,
    values second). Queries and keys are each multiplied by
    `head_dim ** -0.25` before the dot product.
    """

    def __init__(self, head_dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        # See https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py#L69
        # -> "More stable with f16 than dividing afterwards"
        self.scale = 1 / math.sqrt(math.sqrt(head_dim))

    def forward(
        self,
        key_value: Float[Tensor, "batch sequence_length 2*head_dim*num_heads"],
        query: Float[Tensor, "batch num_tokens head_dim*num_heads"],
    ) -> Float[Tensor, "batch num_tokens head_dim*num_heads"]:
        """Attend `query` over `key_value` and merge the heads back together."""
        batch_size, num_queries, _ = query.shape
        keys, values = key_value.chunk(2, dim=-1)

        q, k, v = (self.reshape_tensor(t) for t in (query, keys, values))

        scores = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
        # Softmax is computed in float32, then cast back to the scores' dtype.
        weights = softmax(input=scores.float(), dim=-1).type(scores.dtype)
        attended = weights @ v

        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
        return attended.permute(0, 2, 1, 3).reshape(batch_size, num_queries, -1)

    def reshape_tensor(
        self, x: Float[Tensor, "batch length head_dim*num_heads"]
    ) -> Float[Tensor, "batch num_heads length head_dim"]:
        """Split the last dimension into `num_heads` heads and move heads before length."""
        batch_size, length, _ = x.shape
        return (
            x.view(batch_size, length, self.num_heads, -1)
            .transpose(1, 2)
            .reshape(batch_size, self.num_heads, length, -1)
        )
class PerceiverAttention(fl.Chain):
    """Attention block taking two inputs `(x, latents)`.

    Layout: a pair of LayerNorms, then a key/value branch (fed by `to_kv`) in
    parallel with a query branch (fed by the argument at index 1), then
    `PerceiverScaledDotProductAttention` and an output projection. All linear
    layers are bias-free.
    """

    def __init__(
        self,
        embedding_dim: int,
        head_dim: int = 64,
        num_heads: int = 8,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        inner_dim = head_dim * num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = head_dim
        self.inner_dim = inner_dim

        make_norm = partial(fl.LayerNorm, normalized_shape=embedding_dim, device=device, dtype=dtype)
        make_linear = partial(fl.Linear, bias=False, device=device, dtype=dtype)

        super().__init__(
            fl.Distribute(make_norm(), make_norm()),
            fl.Parallel(
                fl.Chain(
                    fl.Lambda(func=self.to_kv),
                    make_linear(in_features=embedding_dim, out_features=2 * inner_dim),  # Wkv
                ),
                fl.Chain(
                    fl.GetArg(index=1),
                    make_linear(in_features=embedding_dim, out_features=inner_dim),  # Wq
                ),
            ),
            PerceiverScaledDotProductAttention(head_dim=head_dim, num_heads=num_heads),
            make_linear(in_features=inner_dim, out_features=embedding_dim),
        )

    def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
        """Concatenate `x` and `latents` along the second-to-last dimension."""
        return cat((x, latents), dim=-2)
class LatentsEncoder(fl.Chain):
    """Produce the latents parameter expanded to the first dimension of the input.

    The input itself is only used for `x.shape[0]`; the output is the
    parameter expanded with a leading dimension of that size.
    """

    def __init__(
        self,
        num_tokens: int,
        embeddding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # presumably a parameter of shape (num_tokens, embeddding_dim) -- confirm against fl.Parameter
        latents = fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)
        passthrough_and_latents = fl.Parallel(fl.Identity(), latents)
        super().__init__(
            passthrough_and_latents,
            fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)),
        )
class Transformer(fl.Chain):
    """Named chain wrapping the stack of `TransformerLayer`s (adds no behavior of its own)."""
class TransformerLayer(fl.Chain):
    """Named chain wrapping a single transformer layer (adds no behavior of its own)."""
class PerceiverResampler(fl.Chain):
    """Resample input features into `num_tokens` latent tokens.

    Pipeline: project the input from `input_dim` to `latents_dim`, set it in
    the "perceiver_resampler" context under key "x", swap it for the latents,
    run `num_attention_layers` layers (attention between the context value
    and the latents, then a feed-forward, both wrapped in `fl.Residual`),
    and finally project to `output_dim` followed by a LayerNorm.
    """

    def __init__(
        self,
        latents_dim: int = 1024,
        num_attention_layers: int = 8,
        num_attention_heads: int = 16,
        head_dim: int = 64,
        num_tokens: int = 8,
        input_dim: int = 768,
        output_dim: int = 1024,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.latents_dim = latents_dim
        self.num_attention_layers = num_attention_layers
        self.head_dim = head_dim
        self.num_attention_heads = num_attention_heads
        self.num_tokens = num_tokens
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.feedforward_dim = 4 * self.latents_dim

        def make_layer() -> TransformerLayer:
            # One layer = attention over (context "x", latents) + feed-forward.
            attention = fl.Residual(
                fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
                PerceiverAttention(
                    embedding_dim=latents_dim,
                    head_dim=head_dim,
                    num_heads=num_attention_heads,
                    device=device,
                    dtype=dtype,
                ),
            )
            feedforward = fl.Residual(
                fl.LayerNorm(normalized_shape=latents_dim, device=device, dtype=dtype),
                FeedForward(
                    embedding_dim=latents_dim, feedforward_dim=self.feedforward_dim, device=device, dtype=dtype
                ),
            )
            return TransformerLayer(attention, feedforward)

        super().__init__(
            fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
            fl.SetContext(context="perceiver_resampler", key="x"),
            LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype),
            # The layers are handed over as a single generator argument.
            Transformer(make_layer() for _ in range(num_attention_layers)),
            fl.Linear(in_features=latents_dim, out_features=output_dim, device=device, dtype=dtype),
            fl.LayerNorm(normalized_shape=output_dim, device=device, dtype=dtype),
        )

    def init_context(self) -> Contexts:
        """Declare the "perceiver_resampler" context with an empty "x" slot."""
        return {"perceiver_resampler": {"x": None}}
class _CrossAttnIndex(IntEnum):
TXT_CROSS_ATTN = 0 # text cross-attention
IMG_CROSS_ATTN = 1 # image cross-attention
@ -164,32 +360,29 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
# Prevent PyTorch module registration
_clip_image_encoder: list[CLIPImageEncoderH]
_image_proj: list[ImageProjection]
_grid_image_encoder: list[CLIPImageEncoderH]
_image_proj: list[fl.Module]
def __init__(
self,
target: T,
clip_image_encoder: CLIPImageEncoderH | None = None,
clip_image_encoder: CLIPImageEncoderH,
image_proj: fl.Module,
scale: float = 1.0,
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)]
self._image_proj = [
ImageProjection(
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
]
self.fine_grained = fine_grained
self._clip_image_encoder = [clip_image_encoder]
if fine_grained:
self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]
self._image_proj = [image_proj]
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale)
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]
@ -214,7 +407,12 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
return self._clip_image_encoder[0]
@property
def image_proj(self) -> ImageProjection:
def grid_image_encoder(self) -> CLIPImageEncoderH:
assert hasattr(self, "_grid_image_encoder")
return self._grid_image_encoder[0]
@property
def image_proj(self) -> fl.Module:
return self._image_proj[0]
def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
@ -228,10 +426,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
super().eject()
# These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor | None) -> Tensor:
clip_embedding = self.clip_image_encoder(image_prompt)
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
clip_embedding = image_encoder(image_prompt)
conditional_embedding = self.image_proj(clip_embedding)
negative_embedding = self.image_proj(zeros_like(clip_embedding))
if not self.fine_grained:
negative_embedding = self.image_proj(zeros_like(clip_embedding))
else:
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
clip_embedding = image_encoder(zeros_like(image_prompt))
negative_embedding = self.image_proj(clip_embedding)
return cat((negative_embedding, conditional_embedding))
def preprocess_image(
@ -247,3 +451,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
)
@staticmethod
def convert_to_grid_features(clip_image_encoder: CLIPImageEncoderH) -> CLIPImageEncoderH:
    """Return a structural copy of the encoder with its last layers removed.

    The last three layers (pooling, final normalization, final projection)
    are dropped, as well as the last of the 32 transformer layers.
    """
    grid_encoder = clip_image_encoder.structural_copy()

    # Index from the end -> expected layer type.
    expected_tail = (
        (-1, fl.Linear),  # final proj
        (-2, fl.LayerNorm),  # final normalization
        (-3, fl.Lambda),  # pooling (classif token)
    )
    for index, layer_type in expected_tail:
        assert isinstance(grid_encoder[index], layer_type)
    for _ in expected_tail:
        grid_encoder.pop()

    transformer_layers = grid_encoder[-1]
    assert isinstance(transformer_layers, fl.Chain) and len(transformer_layers) == 32
    transformer_layers.pop()
    return grid_encoder

View file

@ -1,6 +1,53 @@
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
from torch import Tensor
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
class SD1IPAdapter(IPAdapter[SD1UNet]):
    """IP-Adapter for `SD1UNet` that supplies default sub-modules.

    When not provided, the CLIP image encoder defaults to `CLIPImageEncoderH`
    and the image projection defaults to `ImageProjection`, or to a
    `PerceiverResampler` when `fine_grained` is set.
    """

    def __init__(
        self,
        target: SD1UNet,
        clip_image_encoder: CLIPImageEncoderH | None = None,
        image_proj: ImageProjection | PerceiverResampler | None = None,
        scale: float = 1.0,
        fine_grained: bool = False,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)

        if image_proj is not None:
            # A user-supplied projection must match the fine-grained mode.
            if fine_grained:
                assert isinstance(image_proj, PerceiverResampler)
        else:
            cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
            if fine_grained:
                image_proj = PerceiverResampler(
                    latents_dim=cross_attn_2d.context_embedding_dim,
                    num_attention_layers=4,
                    num_attention_heads=12,
                    head_dim=64,
                    num_tokens=16,
                    input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection
                    output_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )
            else:
                image_proj = ImageProjection(
                    clip_image_embedding_dim=clip_image_encoder.output_dim,
                    clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )

        super().__init__(
            target=target,
            clip_image_encoder=clip_image_encoder,
            image_proj=image_proj,
            scale=scale,
            fine_grained=fine_grained,
            weights=weights,
        )

View file

@ -1,6 +1,53 @@
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
from torch import Tensor
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
    """IP-Adapter for `SDXLUNet` that supplies default sub-modules.

    When not provided, the CLIP image encoder defaults to `CLIPImageEncoderH`
    and the image projection defaults to `ImageProjection`, or to a
    `PerceiverResampler` when `fine_grained` is set.
    """

    def __init__(
        self,
        target: SDXLUNet,
        clip_image_encoder: CLIPImageEncoderH | None = None,
        image_proj: ImageProjection | PerceiverResampler | None = None,
        scale: float = 1.0,
        fine_grained: bool = False,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)

        if image_proj is not None:
            # A user-supplied projection must match the fine-grained mode.
            if fine_grained:
                assert isinstance(image_proj, PerceiverResampler)
        else:
            cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
            if fine_grained:
                image_proj = PerceiverResampler(
                    latents_dim=1280,  # not `cross_attn_2d.context_embedding_dim` in this case
                    num_attention_layers=4,
                    num_attention_heads=20,
                    head_dim=64,
                    num_tokens=16,
                    input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection
                    output_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )
            else:
                image_proj = ImageProjection(
                    clip_image_embedding_dim=clip_image_encoder.output_dim,
                    clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )

        super().__init__(
            target=target,
            clip_image_encoder=clip_image_encoder,
            image_proj=image_proj,
            scale=scale,
            fine_grained=fine_grained,
            weights=weights,
        )

View file

@ -54,6 +54,11 @@ def woman_image(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "woman.png").convert("RGB")
@pytest.fixture(scope="module")
def statue_image(ref_path: Path) -> Image.Image:
    """Open `statue.png` from the reference directory, converted to RGB."""
    image_path = ref_path / "statue.png"
    return Image.open(image_path).convert("RGB")
@pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
    """Open `expected_std_random_init.png` from the reference directory, converted to RGB."""
    image_path = ref_path / "expected_std_random_init.png"
    return Image.open(image_path).convert("RGB")
@ -79,11 +84,21 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
    """Open `expected_image_ip_adapter_plus_statue.png` from the reference directory, converted to RGB."""
    image_path = ref_path / "expected_image_ip_adapter_plus_statue.png"
    return Image.open(image_path).convert("RGB")
@pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
    """Open `expected_image_sdxl_ip_adapter_woman.png` from the reference directory, converted to RGB."""
    image_path = ref_path / "expected_image_sdxl_ip_adapter_woman.png"
    return Image.open(image_path).convert("RGB")
@pytest.fixture
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
    """Open `expected_image_sdxl_ip_adapter_plus_woman.png` from the reference directory, converted to RGB."""
    image_path = ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png"
    return Image.open(image_path).convert("RGB")
@pytest.fixture
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
    """Open `expected_ip_adapter_controlnet.png` from the reference directory, converted to RGB."""
    image_path = ref_path / "expected_ip_adapter_controlnet.png"
    return Image.open(image_path).convert("RGB")
@ -255,6 +270,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
return ip_adapter_weights
@pytest.fixture(scope="module")
def ip_adapter_plus_weights(test_weights_path: Path) -> Path:
    """Return the path to the SD 1.5 IP-Adapter plus weights, skipping when the file is missing."""
    weights_path = test_weights_path / "ip-adapter-plus_sd15.safetensors"
    if weights_path.is_file():
        return weights_path
    # Missing weights: warn and skip instead of failing.
    warn(f"could not find weights at {weights_path}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module")
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
@ -264,6 +288,15 @@ def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
    """Return the path to the SDXL IP-Adapter plus weights, skipping when the file is missing."""
    weights_path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
    if weights_path.is_file():
        return weights_path
    # Missing weights: warn and skip instead of failing.
    warn(f"could not find weights at {weights_path}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
@ -1189,6 +1222,115 @@ def test_diffusion_ip_adapter_controlnet(
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
@torch.no_grad()
def test_diffusion_ip_adapter_plus(
    sd15_ddim_lda_ft_mse: StableDiffusion_1,
    ip_adapter_plus_weights: Path,
    image_encoder_weights: Path,
    statue_image: Image.Image,
    expected_image_ip_adapter_plus_statue: Image.Image,
    test_device: torch.device,
):
    """End-to-end check of the fine-grained (plus) IP-Adapter on SD 1.5 in float16.

    Runs 50 denoising steps conditioned on a text prompt and the statue image
    prompt, then compares the decoded image against the stored reference.
    """
    sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
    n_steps = 50

    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SD1IPAdapter(
        target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True
    )
    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
    ip_adapter.inject()

    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))

    # Append the image tokens after the text tokens, separately for the
    # negative and the conditional halves.
    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

    clip_text_embedding = torch.cat(
        (
            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
        )
    )

    sd15.set_num_inference_steps(n_steps)

    manual_seed(42)  # seed=42 is used in the official IP-Adapter demo
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.decode_latents(x)

    # The prediction is only compared against the reference; the test does
    # not write it to the working directory.
    ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_sdxl_ip_adapter_plus(
    sdxl_ddim: StableDiffusion_XL,
    sdxl_ip_adapter_plus_weights: Path,
    image_encoder_weights: Path,
    woman_image: Image.Image,
    expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
    test_device: torch.device,
):
    """End-to-end check of the fine-grained (plus) IP-Adapter on SDXL in float16.

    Runs 30 denoising steps conditioned on a text prompt and the woman image
    prompt, decodes the latents in float32 and compares the result against
    the stored reference.
    """
    sdxl = sdxl_ddim.to(dtype=torch.float16)
    n_steps = 30

    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SDXLIPAdapter(
        target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True
    )
    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
    ip_adapter.inject()

    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))

    # Append the image tokens after the text tokens, separately for the
    # negative and the conditional halves.
    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

    clip_text_embedding = torch.cat(
        (
            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
        )
    )
    time_ids = sdxl.default_time_ids
    sdxl.set_num_inference_steps(n_steps)

    manual_seed(2)
    x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    # The latents are decoded in float32 rather than float16.
    sdxl.lda.to(dtype=torch.float32)
    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))

    # The prediction is only compared against the reference; the test does
    # not write it to the working directory.
    ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
@torch.no_grad()
def test_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device

View file

@ -41,6 +41,7 @@ Special cases:
- `expected_image_sdxl_ip_adapter_woman.png`
- `expected_ip_adapter_controlnet.png`
- `expected_t2i_adapter_xl_canny.png`
- `expected_image_sdxl_ip_adapter_plus_woman.png`
## Other images
@ -68,6 +69,8 @@ Special cases:
- `woman.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/8b96670cc5c8ef00278b42c0c7b62fe8a74510b9/assets/images/woman.png).
- `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png).
## VAE without randomness
```diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 269 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 330 KiB