diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index ead2065..f2a6b84 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -70,12 +70,14 @@ def main() -> None: assert isinstance(weights, dict) assert sorted(weights.keys()) == ["image_proj", "ip_adapter"] + fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus + match len(weights["ip_adapter"]): case 32: - ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4)) + ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained) cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"] case 140: - ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4)) + ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4), fine_grained=fine_grained) cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"] case _: raise ValueError("Unexpected number of keys in input checkpoint") @@ -86,12 +88,43 @@ def main() -> None: state_dict: dict[str, torch.Tensor] = {} image_proj_weights = weights["image_proj"] - image_proj_state_dict: dict[str, torch.Tensor] = { - "Linear.weight": image_proj_weights["proj.weight"], - "Linear.bias": image_proj_weights["proj.bias"], - "LayerNorm.weight": image_proj_weights["norm.weight"], - "LayerNorm.bias": image_proj_weights["norm.bias"], - } + image_proj_state_dict: dict[str, torch.Tensor] + + if fine_grained: + w = image_proj_weights + image_proj_state_dict = { + "LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0), # drop batch dim = 1 + "Linear_1.weight": w["proj_in.weight"], + "Linear_1.bias": w["proj_in.bias"], + "Linear_2.weight": w["proj_out.weight"], + "Linear_2.bias": w["proj_out.bias"], + "LayerNorm.weight": w["norm_out.weight"], + "LayerNorm.bias": w["norm_out.bias"], + } + for i in range(4): + t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}." 
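+                # For block i, map the diffusers Resampler keys ("layers.{i}.*") onto the refiners
+                # PerceiverResampler layout: Residual_1 wraps the PerceiverAttention, Residual_2 the FeedForward.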
+ image_proj_state_dict.update( + { + f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"], + f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"], + f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"], + f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"], + f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"], + f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"], + f"{t_pfx}1.Chain.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"], + f"{t_pfx}2.Chain.LayerNorm.weight": w[f"{s_pfx}1.0.weight"], + f"{t_pfx}2.Chain.LayerNorm.bias": w[f"{s_pfx}1.0.bias"], + f"{t_pfx}2.Chain.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"], + f"{t_pfx}2.Chain.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"], + } + ) + else: + image_proj_state_dict = { + "Linear.weight": image_proj_weights["proj.weight"], + "Linear.bias": image_proj_weights["proj.bias"], + "LayerNorm.weight": image_proj_weights["norm.weight"], + "LayerNorm.bias": image_proj_weights["norm.bias"], + } ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict) for k, v in image_proj_state_dict.items(): diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 1734729..0e67ced 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -1,14 +1,16 @@ from enum import IntEnum from functools import partial from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING +import math -from torch import Tensor, cat, zeros_like, device as Device, dtype as DType +from jaxtyping import Float +from torch import Tensor, cat, softmax, zeros_like, device as Device, dtype as DType from PIL import Image from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.lora import Lora from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH -from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d +from refiners.fluxion.context import Contexts from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.fluxion.utils import image_to_tensor, normalize import refiners.fluxion.layers as fl @@ -26,25 +28,219 @@ class ImageProjection(fl.Chain): self, clip_image_embedding_dim: int = 1024, clip_text_embedding_dim: int = 768, - sequence_length: int = 4, + num_tokens: int = 4, device: Device | str | None = None, dtype: DType | None = None, ) -> None: self.clip_image_embedding_dim = clip_image_embedding_dim self.clip_text_embedding_dim = clip_text_embedding_dim - self.sequence_length = sequence_length + self.num_tokens = num_tokens super().__init__( fl.Linear( in_features=clip_image_embedding_dim, - out_features=clip_text_embedding_dim * sequence_length, + out_features=clip_text_embedding_dim * num_tokens, device=device, dtype=dtype, ), - fl.Reshape(sequence_length, clip_text_embedding_dim), + fl.Reshape(num_tokens, clip_text_embedding_dim), fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype), ) +class FeedForward(fl.Chain): + def __init__( + self, + embedding_dim: int, + feedforward_dim: int, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = 
embedding_dim + self.feedforward_dim = feedforward_dim + super().__init__( + fl.Linear( + in_features=self.embedding_dim, + out_features=self.feedforward_dim, + bias=False, + device=device, + dtype=dtype, + ), + fl.GeLU(), + fl.Linear( + in_features=self.feedforward_dim, + out_features=self.embedding_dim, + bias=False, + device=device, + dtype=dtype, + ), + ) + + +# Adapted from https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py +# See also: +# - https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# - https://github.com/lucidrains/flamingo-pytorch +class PerceiverScaledDotProductAttention(fl.Module): + def __init__(self, head_dim: int, num_heads: int) -> None: + super().__init__() + self.num_heads = num_heads + # See https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py#L69 + # -> "More stable with f16 than dividing afterwards" + self.scale = 1 / math.sqrt(math.sqrt(head_dim)) + + def forward( + self, + key_value: Float[Tensor, "batch sequence_length 2*head_dim*num_heads"], + query: Float[Tensor, "batch num_tokens head_dim*num_heads"], + ) -> Float[Tensor, "batch num_tokens head_dim*num_heads"]: + bs, length, _ = query.shape + key, value = key_value.chunk(2, dim=-1) + + q = self.reshape_tensor(query) + k = self.reshape_tensor(key) + v = self.reshape_tensor(value) + + attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1) + attention = softmax(input=attention.float(), dim=-1).type(attention.dtype) + attention = attention @ v + + return attention.permute(0, 2, 1, 3).reshape(bs, length, -1) + + def reshape_tensor( + self, x: Float[Tensor, "batch length head_dim*num_heads"] + ) -> Float[Tensor, "batch num_heads length head_dim"]: + bs, length, _ = x.shape + x = x.view(bs, length, self.num_heads, -1) + x = x.transpose(1, 2) + x = x.reshape(bs, self.num_heads, length, -1) + return x + + +class PerceiverAttention(fl.Chain): + def __init__( + self, + embedding_dim: int, + head_dim: int = 64, + num_heads: int = 8, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + super().__init__( + fl.Distribute( + fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype), + fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype), + ), + fl.Parallel( + fl.Chain( + fl.Lambda(func=self.to_kv), + fl.Linear( + in_features=self.embedding_dim, + out_features=2 * self.inner_dim, + bias=False, + device=device, + dtype=dtype, + ), # Wkv + ), + fl.Chain( + fl.GetArg(index=1), + fl.Linear( + in_features=self.embedding_dim, + out_features=self.inner_dim, + bias=False, + device=device, + dtype=dtype, + ), # Wq + ), + ), + PerceiverScaledDotProductAttention(head_dim=head_dim, num_heads=num_heads), + fl.Linear( + in_features=self.inner_dim, out_features=self.embedding_dim, bias=False, device=device, dtype=dtype + ), + ) + + def to_kv(self, x: Tensor, latents: Tensor) -> Tensor: + return cat((x, latents), dim=-2) + + +class LatentsEncoder(fl.Chain): + def __init__( + self, + num_tokens: int, + embeddding_dim: int, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + fl.Parallel(fl.Identity(), fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)), + fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)), + ) + + +class Transformer(fl.Chain): + pass + + +class TransformerLayer(fl.Chain): + 
pass + + +class PerceiverResampler(fl.Chain): + def __init__( + self, + latents_dim: int = 1024, + num_attention_layers: int = 8, + num_attention_heads: int = 16, + head_dim: int = 64, + num_tokens: int = 8, + input_dim: int = 768, + output_dim: int = 1024, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.latents_dim = latents_dim + self.num_attention_layers = num_attention_layers + self.head_dim = head_dim + self.num_attention_heads = num_attention_heads + self.num_tokens = num_tokens + self.input_dim = input_dim + self.output_dim = output_dim + self.feedforward_dim = 4 * self.latents_dim + super().__init__( + fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype), + fl.SetContext(context="perceiver_resampler", key="x"), + LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype), + Transformer( + TransformerLayer( + fl.Residual( + fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()), + PerceiverAttention( + embedding_dim=latents_dim, + head_dim=head_dim, + num_heads=num_attention_heads, + device=device, + dtype=dtype, + ), + ), + fl.Residual( + fl.LayerNorm(normalized_shape=latents_dim, device=device, dtype=dtype), + FeedForward( + embedding_dim=latents_dim, feedforward_dim=self.feedforward_dim, device=device, dtype=dtype + ), + ), + ) + for _ in range(num_attention_layers) + ), + fl.Linear(in_features=latents_dim, out_features=output_dim, device=device, dtype=dtype), + fl.LayerNorm(normalized_shape=output_dim, device=device, dtype=dtype), + ) + + def init_context(self) -> Contexts: + return {"perceiver_resampler": {"x": None}} + + class _CrossAttnIndex(IntEnum): TXT_CROSS_ATTN = 0 # text cross-attention IMG_CROSS_ATTN = 1 # image cross-attention @@ -164,32 +360,29 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): class IPAdapter(Generic[T], fl.Chain, Adapter[T]): # Prevent PyTorch module registration _clip_image_encoder: list[CLIPImageEncoderH] - _image_proj: list[ImageProjection] + _grid_image_encoder: list[CLIPImageEncoderH] + _image_proj: list[fl.Module] def __init__( self, target: T, - clip_image_encoder: CLIPImageEncoderH | None = None, + clip_image_encoder: CLIPImageEncoderH, + image_proj: fl.Module, scale: float = 1.0, + fine_grained: bool = False, weights: dict[str, Tensor] | None = None, ) -> None: with self.setup_adapter(target): super().__init__(target) - cross_attn_2d = target.ensure_find(CrossAttentionBlock2d) - - self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)] - self._image_proj = [ - ImageProjection( - clip_image_embedding_dim=self.clip_image_encoder.output_dim, - clip_text_embedding_dim=cross_attn_2d.context_embedding_dim, - device=target.device, - dtype=target.dtype, - ) - ] + self.fine_grained = fine_grained + self._clip_image_encoder = [clip_image_encoder] + if fine_grained: + self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)] + self._image_proj = [image_proj] self.sub_adapters = [ - CrossAttentionAdapter(target=cross_attn, scale=scale) + CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens) for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention)) ] @@ -214,7 +407,12 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): return self._clip_image_encoder[0] @property - def image_proj(self) -> ImageProjection: + def grid_image_encoder(self) -> 
CLIPImageEncoderH: + assert hasattr(self, "_grid_image_encoder") + return self._grid_image_encoder[0] + + @property + def image_proj(self) -> fl.Module: return self._image_proj[0] def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter": @@ -228,10 +426,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): super().eject() # These should be concatenated to the CLIP text embedding before setting the UNet context - def compute_clip_image_embedding(self, image_prompt: Tensor | None) -> Tensor: - clip_embedding = self.clip_image_encoder(image_prompt) + def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor: + image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder + clip_embedding = image_encoder(image_prompt) conditional_embedding = self.image_proj(clip_embedding) - negative_embedding = self.image_proj(zeros_like(clip_embedding)) + if not self.fine_grained: + negative_embedding = self.image_proj(zeros_like(clip_embedding)) + else: + # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352 + clip_embedding = image_encoder(zeros_like(image_prompt)) + negative_embedding = self.image_proj(clip_embedding) return cat((negative_embedding, conditional_embedding)) def preprocess_image( @@ -247,3 +451,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean, std=[0.26862954, 0.26130258, 0.27577711] if std is None else std, ) + + @staticmethod + def convert_to_grid_features(clip_image_encoder: CLIPImageEncoderH) -> CLIPImageEncoderH: + encoder_clone = clip_image_encoder.structural_copy() + assert isinstance(encoder_clone[-1], fl.Linear) # final proj + assert isinstance(encoder_clone[-2], fl.LayerNorm) # final normalization + assert isinstance(encoder_clone[-3], fl.Lambda) # pooling (classif token) + for _ in range(3): + encoder_clone.pop() + transfomer_layers = encoder_clone[-1] + assert isinstance(transfomer_layers, fl.Chain) and len(transfomer_layers) == 32 + transfomer_layers.pop() + return encoder_clone diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py index fe16034..d1e3379 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py @@ -1,6 +1,53 @@ -from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter +from torch import Tensor + +from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH +from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d +from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet class SD1IPAdapter(IPAdapter[SD1UNet]): - pass + def __init__( + self, + target: SD1UNet, + clip_image_encoder: CLIPImageEncoderH | None = None, + image_proj: ImageProjection | PerceiverResampler | None = None, + scale: float = 1.0, + fine_grained: bool = False, + weights: dict[str, Tensor] | None = None, + ) -> None: + clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype) + + if image_proj is None: + cross_attn_2d = target.ensure_find(CrossAttentionBlock2d) + image_proj = ( + ImageProjection( + 
clip_image_embedding_dim=clip_image_encoder.output_dim, + clip_text_embedding_dim=cross_attn_2d.context_embedding_dim, + device=target.device, + dtype=target.dtype, + ) + if not fine_grained + else PerceiverResampler( + latents_dim=cross_attn_2d.context_embedding_dim, + num_attention_layers=4, + num_attention_heads=12, + head_dim=64, + num_tokens=16, + input_dim=clip_image_encoder.embedding_dim, # = dim before final projection + output_dim=cross_attn_2d.context_embedding_dim, + device=target.device, + dtype=target.dtype, + ) + ) + elif fine_grained: + assert isinstance(image_proj, PerceiverResampler) + + super().__init__( + target=target, + clip_image_encoder=clip_image_encoder, + image_proj=image_proj, + scale=scale, + fine_grained=fine_grained, + weights=weights, + ) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py index d78b00a..93d9e8b 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py @@ -1,6 +1,53 @@ -from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter +from torch import Tensor + +from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH +from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d +from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet class SDXLIPAdapter(IPAdapter[SDXLUNet]): - pass + def __init__( + self, + target: SDXLUNet, + clip_image_encoder: CLIPImageEncoderH | None = None, + image_proj: ImageProjection | PerceiverResampler | None = None, + scale: float = 1.0, + fine_grained: bool = False, + weights: dict[str, Tensor] | None = None, + ) -> None: + clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype) + + if image_proj is None: + cross_attn_2d = target.ensure_find(CrossAttentionBlock2d) + image_proj = ( + ImageProjection( + clip_image_embedding_dim=clip_image_encoder.output_dim, + clip_text_embedding_dim=cross_attn_2d.context_embedding_dim, + device=target.device, + dtype=target.dtype, + ) + if not fine_grained + else PerceiverResampler( + latents_dim=1280, # not `cross_attn_2d.context_embedding_dim` in this case + num_attention_layers=4, + num_attention_heads=20, + head_dim=64, + num_tokens=16, + input_dim=clip_image_encoder.embedding_dim, # = dim before final projection + output_dim=cross_attn_2d.context_embedding_dim, + device=target.device, + dtype=target.dtype, + ) + ) + elif fine_grained: + assert isinstance(image_proj, PerceiverResampler) + + super().__init__( + target=target, + clip_image_encoder=clip_image_encoder, + image_proj=image_proj, + scale=scale, + fine_grained=fine_grained, + weights=weights, + ) diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 921bc9b..115e79b 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -54,6 +54,11 @@ def woman_image(ref_path: Path) -> Image.Image: return Image.open(ref_path / "woman.png").convert("RGB") +@pytest.fixture(scope="module") +def statue_image(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "statue.png").convert("RGB") + + @pytest.fixture def expected_image_std_random_init(ref_path: Path) -> Image.Image: return Image.open(ref_path / 
"expected_std_random_init.png").convert("RGB") @@ -79,11 +84,21 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image: return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB") +@pytest.fixture +def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB") + + @pytest.fixture def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image: return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB") +@pytest.fixture +def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB") + + @pytest.fixture def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image: return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB") @@ -255,6 +270,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path: return ip_adapter_weights +@pytest.fixture(scope="module") +def ip_adapter_plus_weights(test_weights_path: Path) -> Path: + ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors" + if not ip_adapter_weights.is_file(): + warn(f"could not find weights at {ip_adapter_weights}, skipping") + pytest.skip(allow_module_level=True) + return ip_adapter_weights + + @pytest.fixture(scope="module") def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path: ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors" @@ -264,6 +288,15 @@ def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path: return ip_adapter_weights +@pytest.fixture(scope="module") +def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path: + ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors" + if not ip_adapter_weights.is_file(): + warn(f"could not find weights at {ip_adapter_weights}, skipping") + pytest.skip(allow_module_level=True) + return ip_adapter_weights + + @pytest.fixture(scope="module") def image_encoder_weights(test_weights_path: Path) -> Path: image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors" @@ -1189,6 +1222,115 @@ def test_diffusion_ip_adapter_controlnet( ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet) +@torch.no_grad() +def test_diffusion_ip_adapter_plus( + sd15_ddim_lda_ft_mse: StableDiffusion_1, + ip_adapter_plus_weights: Path, + image_encoder_weights: Path, + statue_image: Image.Image, + expected_image_ip_adapter_plus_statue: Image.Image, + test_device: torch.device, +): + sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16) + n_steps = 50 + + prompt = "best quality, high quality" + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + ip_adapter = SD1IPAdapter( + target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True + ) + ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter.inject() + + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image)) + + negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) + negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) + + clip_text_embedding = torch.cat( + ( + torch.cat([negative_text_embedding, negative_image_embedding], 
dim=1), + torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), + ) + ) + + sd15.set_num_inference_steps(n_steps) + + manual_seed(42) # seed=42 is used in the official IP-Adapter demo + x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) + + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) + predicted_image.save("output.png") + + ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98) + + +@torch.no_grad() +def test_diffusion_sdxl_ip_adapter_plus( + sdxl_ddim: StableDiffusion_XL, + sdxl_ip_adapter_plus_weights: Path, + image_encoder_weights: Path, + woman_image: Image.Image, + expected_image_sdxl_ip_adapter_plus_woman: Image.Image, + test_device: torch.device, +): + sdxl = sdxl_ddim.to(dtype=torch.float16) + n_steps = 30 + + prompt = "best quality, high quality" + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + ip_adapter = SDXLIPAdapter( + target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True + ) + ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter.inject() + + clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( + text=prompt, negative_text=negative_prompt + ) + clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) + + negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) + negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) + + clip_text_embedding = torch.cat( + ( + torch.cat([negative_text_embedding, negative_image_embedding], dim=1), + torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), + ) + ) + time_ids = sdxl.default_time_ids + sdxl.set_num_inference_steps(n_steps) + + manual_seed(2) + x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16) + + for step in sdxl.steps: + x = sdxl( + x, + step=step, + clip_text_embedding=clip_text_embedding, + pooled_text_embedding=pooled_text_embedding, + time_ids=time_ids, + condition_scale=5, + ) + sdxl.lda.to(dtype=torch.float32) + predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32)) + predicted_image.save("output.png") + + ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman) + + @torch.no_grad() def test_sdxl_random_init( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index ece1f96..ca58f65 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -41,6 +41,7 @@ Special cases: - `expected_image_sdxl_ip_adapter_woman.png` - `expected_ip_adapter_controlnet.png` - `expected_t2i_adapter_xl_canny.png` + - `expected_image_sdxl_ip_adapter_plus_woman.png` ## Other images @@ -68,6 +69,8 @@ Special cases: - `woman.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/8b96670cc5c8ef00278b42c0c7b62fe8a74510b9/assets/images/woman.png). +- `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png). 
+ ## VAE without randomness ```diff diff --git a/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_plus_statue.png b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_plus_statue.png new file mode 100644 index 0000000..d719faa Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_plus_statue.png differ diff --git a/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_plus_woman.png b/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_plus_woman.png new file mode 100644 index 0000000..ebbc319 Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_plus_woman.png differ diff --git a/tests/e2e/test_diffusion_ref/statue.png b/tests/e2e/test_diffusion_ref/statue.png new file mode 100644 index 0000000..68fdece Binary files /dev/null and b/tests/e2e/test_diffusion_ref/statue.png differ
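
Usage sketch for the new fine-grained path (IP-Adapter plus), closely mirroring the `test_diffusion_ip_adapter_plus` e2e test added above. The weight filenames, the direct `StableDiffusion_1` construction, the device/dtype handling and the `refiners.fluxion.utils` / `refiners.foundationals.latent_diffusion` import locations are assumptions, not part of this patch:

```python
import torch
from PIL import Image

from refiners.fluxion.utils import load_from_safetensors, manual_seed  # assumed import location
from refiners.foundationals.latent_diffusion import StableDiffusion_1  # assumed export path
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter

# Weights converted with scripts/conversion/convert_diffusers_ip_adapter.py (filenames assumed).
sd15 = StableDiffusion_1(device="cuda", dtype=torch.float16)
ip_adapter = SD1IPAdapter(
    target=sd15.unet,
    weights=load_from_safetensors("ip-adapter-plus_sd15.safetensors"),
    fine_grained=True,  # use the PerceiverResampler image projection instead of the linear ImageProjection
)
ip_adapter.clip_image_encoder.load_from_safetensors("CLIPImageEncoderH.safetensors")
ip_adapter.inject()

# Fine-grained image tokens (16 per image) are computed from grid (pre-pooling) CLIP features.
image_prompt = Image.open("statue.png").convert("RGB")
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))

clip_text_embedding = sd15.compute_clip_text_embedding(
    text="best quality, high quality",
    negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
)

# Concatenate image tokens to the text tokens, keeping negative / conditional halves aligned.
negative_text, conditional_text = clip_text_embedding.chunk(2)
negative_image, conditional_image = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
    (
        torch.cat([negative_text, negative_image], dim=1),
        torch.cat([conditional_text, conditional_image], dim=1),
    )
)

sd15.set_num_inference_steps(50)
manual_seed(42)
x = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
for step in sd15.steps:
    x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
predicted_image = sd15.lda.decode_latents(x)
predicted_image.save("ip_adapter_plus_statue.png")
```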