diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py
index 4e92f59..ead2065 100644
--- a/scripts/conversion/convert_diffusers_ip_adapter.py
+++ b/scripts/conversion/convert_diffusers_ip_adapter.py
@@ -4,25 +4,61 @@ import argparse
 import torch
 
-from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter
+from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter, SDXLUNet, SDXLIPAdapter
 from refiners.fluxion.utils import save_to_safetensors
 
+# Running:
+#
+# from diffusers import UNet2DConditionModel
+# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+# for k in unet.attn_processors.keys():
+#     print(k)
+#
+# Gives:
+#
+# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
+# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
+# ...
+# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
+# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
+# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
+# ...
+# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
+# mid_block.attentions.0.transformer_blocks.0.attn1.processor
+# mid_block.attentions.0.transformer_blocks.0.attn2.processor
+#
+# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
+# indices:
+#
+# DownBlocks -> [1, 3, 5, 7, 9, 11]
+# MiddleBlock -> [31]
+# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
+#
+# Same for SDXL with more layers (70 cross-attentions vs. 16)
+CROSS_ATTN_MAPPING: dict[str, list[int]] = {
+    "sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
+    "sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
+}
+
+
 def main() -> None:
     parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
     parser.add_argument(
         "--from",
         type=str,
+        required=True,
         dest="source_path",
-        default="ip-adapter_sd15.bin",
-        help="Path to the source model. (default: 'ip-adapter_sd15.bin').",
+        help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
     )
     parser.add_argument(
         "--to",
         type=str,
         dest="output_path",
-        default="ip-adapter_sd15.safetensors",
-        help="Path to save the converted model. (default: 'ip-adapter_sd15.safetensors').",
+        default=None,
+        help=(
+            "Path to save the converted model. If not specified, the output path will be the source path with the"
+            " extension changed to .safetensors."
+        ),
     )
     parser.add_argument("--verbose", action="store_true", dest="verbose")
     parser.add_argument("--half", action="store_true", dest="half")
@@ -34,9 +70,15 @@ def main() -> None:
     assert isinstance(weights, dict)
     assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
 
-    unet = SD1UNet(in_channels=4)
-
-    ip_adapter = SD1IPAdapter(target=unet)
+    match len(weights["ip_adapter"]):
+        case 32:
+            ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
+            cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
+        case 140:
+            ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
+            cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
+        case _:
+            raise ValueError("Unexpected number of keys in input checkpoint")
 
     # Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
     # [1]: https://github.com/tencent-ailab/IP-Adapter
@@ -58,34 +100,6 @@ def main() -> None:
     ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
     assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
 
-    # Running:
-    #
-    # from diffusers import UNet2DConditionModel
-    # unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
-    # for k in unet.attn_processors.keys():
-    #     print(k)
-    #
-    # Gives:
-    #
-    # down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
-    # down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
-    # ...
-    # down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
-    # up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
-    # up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
-    # ...
-    # up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
-    # mid_block.attentions.0.transformer_blocks.0.attn1.processor
-    # mid_block.attentions.0.transformer_blocks.0.attn2.processor
-    #
-    # With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
-    # indices:
-    #
-    # DownBlocks -> [1, 3, 5, 7, 9, 11]
-    # MiddleBlock -> [31]
-    # UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
-    cross_attn_mapping: list[int] = [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29]
-
     for i, cross_attn in enumerate(ip_adapter.sub_adapters):
         cross_attn_index = cross_attn_mapping[i]
         k_ip = f"{cross_attn_index}.to_k_ip.weight"
@@ -107,6 +121,8 @@ def main() -> None:
     if args.half:
         state_dict = {key: value.half() for key, value in state_dict.items()}
 
+    if args.output_path is None:
+        args.output_path = f"{Path(args.source_path).stem}.safetensors"
     save_to_safetensors(path=args.output_path, tensors=state_dict)
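Note on the `CROSS_ATTN_MAPPING` introduced above: the ordering can be re-derived from diffusers directly. The following sanity-check sketch is not part of the committed script; it assumes `diffusers` is installed and the SD 1.5 checkpoint can be fetched:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Cross-attentions are the attn2 processors; their position in the ordered
# attn_processors dict is the index used in IP-Adapter checkpoint keys
# such as "1.to_k_ip.weight".
attn2_indices = [i for i, name in enumerate(unet.attn_processors) if "attn2" in name]
assert attn2_indices == list(range(1, 32, 2))  # 16 cross-attentions, all at odd positions

# refiners iterates its sub-adapters down blocks first, then the middle block,
# then up blocks, hence the reordering in CROSS_ATTN_MAPPING["sd15"]:
expected = list(range(1, 12, 2)) + [31] + list(range(13, 30, 2))
assert expected == [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29]
```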
diff --git a/src/refiners/foundationals/clip/image_encoder.py b/src/refiners/foundationals/clip/image_encoder.py
index c752f04..590122a 100644
--- a/src/refiners/foundationals/clip/image_encoder.py
+++ b/src/refiners/foundationals/clip/image_encoder.py
@@ -121,6 +121,7 @@ class CLIPImageEncoder(fl.Chain):
     structural_attrs = [
         "image_size",
         "embedding_dim",
+        "output_dim",
         "patch_size",
         "num_layers",
         "num_attention_heads",
@@ -142,6 +143,7 @@ class CLIPImageEncoder(fl.Chain):
     ) -> None:
         self.image_size = image_size
         self.embedding_dim = embedding_dim
+        self.output_dim = output_dim
         self.patch_size = patch_size
         self.num_layers = num_layers
         self.num_attention_heads = num_attention_heads
diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py
index f548e62..ca07e19 100644
--- a/src/refiners/foundationals/latent_diffusion/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/__init__.py
@@ -15,6 +15,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
     SDXLUNet,
     DoubleTextEncoder,
+    SDXLIPAdapter,
 )
 
@@ -26,6 +27,7 @@ __all__ = [
     "SD1IPAdapter",
     "SDXLUNet",
     "DoubleTextEncoder",
+    "SDXLIPAdapter",
     "DPMSolver",
     "Scheduler",
     "CLIPTextEncoderL",
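A small but user-visible consequence of the two export hunks above: the new adapter ships through the same top-level namespace as its SD 1.5 counterpart, so downstream code can import both symmetrically:

```python
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SDXLIPAdapter
```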
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 0acf298..756f71f 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,21 +1,23 @@
 from enum import IntEnum
 from functools import partial
-from typing import Generic, TypeVar, Any, Callable
+from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
 
 from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
 from PIL import Image
 
 from refiners.fluxion.adapters.adapter import Adapter
 from refiners.fluxion.adapters.lora import Lora
-from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
-from refiners.fluxion.layers.module import Module
+from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor
 import refiners.fluxion.layers as fl
 
-T = TypeVar("T", bound=SD1UNet | SDXLUNet)
+if TYPE_CHECKING:
+    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+
+T = TypeVar("T", bound="SD1UNet | SDXLUNet")
 TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]")  # Self (see PEP 673)
@@ -128,7 +130,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
     def scale_outputs(self, x: Tensor) -> Tensor:
         return x * self.scale
 
-    def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]:
+    def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
         def f(m: fl.Module, _: fl.Chain) -> bool:
             if isinstance(m, Lora):  # do not adapt LoRAs
                 raise StopIteration
@@ -167,15 +169,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
     def __init__(
         self,
         target: T,
-        clip_image_encoder: CLIPImageEncoder,
+        clip_image_encoder: CLIPImageEncoderH | None = None,
         scale: float = 1.0,
         weights: dict[str, Tensor] | None = None,
     ) -> None:
         with self.setup_adapter(target):
             super().__init__(target)
 
-        self.clip_image_encoder = clip_image_encoder
-        self.image_proj = ImageProjection(device=target.device, dtype=target.dtype)
+        cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
+
+        self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
+        self.image_proj = ImageProjection(
+            clip_image_embedding_dim=self.clip_image_encoder.output_dim,
+            clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
+            device=target.device,
+            dtype=target.dtype,
+        )
 
         self.sub_adapters = [
             CrossAttentionAdapter(target=cross_attn, scale=scale)
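The last hunk above is the heart of the refactor: the generic `IPAdapter` now derives `ImageProjection`'s dimensions from its targets instead of hardcoding SD 1.5 values, which is why `output_dim` was promoted to a structural attribute of `CLIPImageEncoder` earlier in this diff (and why the UNet imports could move behind `TYPE_CHECKING`, breaking the import cycle). A rough standalone equivalent of that logic; the dimension values in the comments (1024 for ViT-H, 2048 for SDXL's concatenated text context) are indicative, not taken from the diff:

```python
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion import SDXLUNet
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d

unet = SDXLUNet(in_channels=4)
encoder = CLIPImageEncoderH()

# ensure_find raises if there is no match, so a UNet without cross-attention
# cannot silently yield a misconfigured adapter.
cross_attn_2d = unet.ensure_find(CrossAttentionBlock2d)

clip_image_embedding_dim = encoder.output_dim                  # 1024 for ViT-H
clip_text_embedding_dim = cross_attn_2d.context_embedding_dim  # 2048 for SDXL
```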
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py
index 021fd58..fe16034 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py
@@ -1,22 +1,6 @@
-from torch import Tensor
-
 from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
-from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
-
 
 class SD1IPAdapter(IPAdapter[SD1UNet]):
-    def __init__(
-        self,
-        target: SD1UNet,
-        clip_image_encoder: CLIPImageEncoderH | None = None,
-        scale: float = 1.0,
-        weights: dict[str, Tensor] | None = None,
-    ) -> None:
-        super().__init__(
-            target=target,
-            clip_image_encoder=clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype),
-            scale=scale,
-            weights=weights,
-        )
+    pass
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
index e3500c6..f2b20aa 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
@@ -1,9 +1,12 @@
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
+
 
 __all__ = [
     "SDXLUNet",
     "DoubleTextEncoder",
     "StableDiffusion_XL",
+    "SDXLIPAdapter",
 ]
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py
new file mode 100644
index 0000000..d78b00a
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py
@@ -0,0 +1,6 @@
+from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
+
+
+class SDXLIPAdapter(IPAdapter[SDXLUNet]):
+    pass
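With the subclass in place, wiring the adapter into an SDXL UNet follows the same three steps the e2e test below uses. A minimal usage sketch, assuming the converted adapter weights and the CLIPImageEncoderH weights are available locally (file names are illustrative):

```python
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import SDXLUNet, SDXLIPAdapter

unet = SDXLUNet(in_channels=4)

# The adapter builds its own CLIPImageEncoderH when none is passed in.
ip_adapter = SDXLIPAdapter(target=unet, weights=load_from_safetensors("ip-adapter_sdxl_vit-h.safetensors"))
ip_adapter.clip_image_encoder.load_from_safetensors("CLIPImageEncoderH.safetensors")
ip_adapter.inject()  # patches the UNet's cross-attention layers in place
```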
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index a9bfe96..48f38bf 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion import (
     SD1UNet,
     SD1ControlnetAdapter,
     SD1IPAdapter,
+    SDXLIPAdapter,
 )
 from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
@@ -74,6 +75,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
     return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
@@ -217,6 +223,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
     return ip_adapter_weights
 
 
+@pytest.fixture(scope="module")
+def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
+    ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
+    if not ip_adapter_weights.is_file():
+        warn(f"could not find weights at {ip_adapter_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return ip_adapter_weights
+
+
 @pytest.fixture(scope="module")
 def image_encoder_weights(test_weights_path: Path) -> Path:
     image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
@@ -1050,6 +1065,64 @@ def test_diffusion_ip_adapter(
     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
 
 
+@torch.no_grad()
+def test_diffusion_sdxl_ip_adapter(
+    sdxl_ddim: StableDiffusion_XL,
+    sdxl_ip_adapter_weights: Path,
+    image_encoder_weights: Path,
+    woman_image: Image.Image,
+    expected_image_sdxl_ip_adapter_woman: Image.Image,
+    test_device: torch.device,
+):
+    sdxl = sdxl_ddim.to(dtype=torch.float16)
+    n_steps = 30
+
+    prompt = "best quality, high quality"
+    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
+    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.inject()
+
+    with torch.no_grad():
+        clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+            text=prompt, negative_text=negative_prompt
+        )
+        clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
+
+        negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
+        negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
+
+        clip_text_embedding = torch.cat(
+            (
+                torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
+                torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
+            )
+        )
+    time_ids = sdxl.default_time_ids
+    sdxl.set_num_inference_steps(n_steps)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
+
+    with torch.no_grad():
+        for step in sdxl.steps:
+            x = sdxl(
+                x,
+                step=step,
+                clip_text_embedding=clip_text_embedding,
+                pooled_text_embedding=pooled_text_embedding,
+                time_ids=time_ids,
+                condition_scale=5,
+            )
+        # See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
+        # internal activation values are too big"
+        sdxl.lda.to(dtype=torch.float32)
+        predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+
+    ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
+
+
 @torch.no_grad()
 def test_sdxl_random_init(
     sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md
index 02ee616..084b38b 100644
--- a/tests/e2e/test_diffusion_ref/README.md
+++ b/tests/e2e/test_diffusion_ref/README.md
@@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
 Special cases:
 
 - `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
-- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that it looks reasonable).
+- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that they look reasonable).
 
 ## Other images
diff --git a/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_woman.png b/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_woman.png
new file mode 100644
index 0000000..10e767f
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_woman.png differ
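For readers untangling the nested `torch.cat` calls in `test_diffusion_sdxl_ip_adapter`: image tokens are appended to the text tokens along the sequence axis (`dim=1`), and the negative and conditional branches are then stacked along the batch axis for classifier-free guidance. A shape-only sketch; the dimensions (77 text tokens, 4 image tokens as in the reference IP-Adapter, 2048-wide SDXL context) are assumptions for illustration:

```python
import torch

negative_text = torch.randn(1, 77, 2048)
conditional_text = torch.randn(1, 77, 2048)
negative_image = torch.randn(1, 4, 2048)  # 4 tokens out of ImageProjection
conditional_image = torch.randn(1, 4, 2048)

clip_text_embedding = torch.cat(
    (
        torch.cat([negative_text, negative_image], dim=1),        # unconditional branch
        torch.cat([conditional_text, conditional_image], dim=1),  # conditional branch
    )
)
assert clip_text_embedding.shape == (2, 81, 2048)  # batch of 2 for CFG
```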