add support for SDXL IP-Adapter

This only supports the latest SDXL IP-Adapter release (2023.9.8) which
builds upon the ViT-H/14 CLIP image encoder.
This commit is contained in:
Cédric Deltheil 2023-09-12 17:28:13 +02:00 committed by Cédric Deltheil
parent 1b4dcebe06
commit eea340c6c4
10 changed files with 159 additions and 64 deletions

View file

@ -4,25 +4,61 @@ import argparse
import torch
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter, SDXLUNet, SDXLIPAdapter
from refiners.fluxion.utils import save_to_safetensors
# Running:
#
# from diffusers import UNet2DConditionModel
# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# for k in unet.attn_processors.keys():
# print(k)
#
# Gives:
#
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
# ...
# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
# ...
# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
# mid_block.attentions.0.transformer_blocks.0.attn2.processor
#
# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
# indices:
#
# DownBlocks -> [1, 3, 5, 7, 9, 11]
# MiddleBlock -> [31]
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
#
# Same for SDXL with more layers (70 cross-attentions vs. 16)
CROSS_ATTN_MAPPING: dict[str, list[int]] = {
"sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
"sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
}
def main() -> None:
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
default="ip-adapter_sd15.bin",
help="Path to the source model. (default: 'ip-adapter_sd15.bin').",
help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default="ip-adapter_sd15.safetensors",
help="Path to save the converted model. (default: 'ip-adapter_sd15.safetensors').",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--verbose", action="store_true", dest="verbose")
parser.add_argument("--half", action="store_true", dest="half")
@ -34,9 +70,15 @@ def main() -> None:
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
unet = SD1UNet(in_channels=4)
ip_adapter = SD1IPAdapter(target=unet)
match len(weights["ip_adapter"]):
case 32:
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
case 140:
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
case _:
raise ValueError("Unexpected number of keys in input checkpoint")
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
# [1]: https://github.com/tencent-ailab/IP-Adapter
@ -58,34 +100,6 @@ def main() -> None:
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
# Running:
#
# from diffusers import UNet2DConditionModel
# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# for k in unet.attn_processors.keys():
# print(k)
#
# Gives:
#
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
# ...
# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
# ...
# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
# mid_block.attentions.0.transformer_blocks.0.attn2.processor
#
# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
# indices:
#
# DownBlocks -> [1, 3, 5, 7, 9, 11]
# MiddleBlock -> [31]
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
cross_attn_mapping: list[int] = [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29]
for i, cross_attn in enumerate(ip_adapter.sub_adapters):
cross_attn_index = cross_attn_mapping[i]
k_ip = f"{cross_attn_index}.to_k_ip.weight"
@ -107,6 +121,8 @@ def main() -> None:
if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=state_dict)

View file

@ -121,6 +121,7 @@ class CLIPImageEncoder(fl.Chain):
structural_attrs = [
"image_size",
"embedding_dim",
"output_dim",
"patch_size",
"num_layers",
"num_attention_heads",
@ -142,6 +143,7 @@ class CLIPImageEncoder(fl.Chain):
) -> None:
self.image_size = image_size
self.embedding_dim = embedding_dim
self.output_dim = output_dim
self.patch_size = patch_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads

View file

@ -15,6 +15,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
SDXLUNet,
DoubleTextEncoder,
SDXLIPAdapter,
)
@ -26,6 +27,7 @@ __all__ = [
"SD1IPAdapter",
"SDXLUNet",
"DoubleTextEncoder",
"SDXLIPAdapter",
"DPMSolver",
"Scheduler",
"CLIPTextEncoderL",

View file

@ -1,21 +1,23 @@
from enum import IntEnum
from functools import partial
from typing import Generic, TypeVar, Any, Callable
from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
from PIL import Image
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Lora
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.fluxion.layers.module import Module
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor
import refiners.fluxion.layers as fl
T = TypeVar("T", bound=SD1UNet | SDXLUNet)
if TYPE_CHECKING:
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
@ -128,7 +130,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt LoRAs
raise StopIteration
@ -167,15 +169,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(
self,
target: T,
clip_image_encoder: CLIPImageEncoder,
clip_image_encoder: CLIPImageEncoderH | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
self.clip_image_encoder = clip_image_encoder
self.image_proj = ImageProjection(device=target.device, dtype=target.dtype)
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
self.image_proj = ImageProjection(
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale)

View file

@ -1,22 +1,6 @@
from torch import Tensor
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
class SD1IPAdapter(IPAdapter[SD1UNet]):
def __init__(
self,
target: SD1UNet,
clip_image_encoder: CLIPImageEncoderH | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
super().__init__(
target=target,
clip_image_encoder=clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype),
scale=scale,
weights=weights,
)
pass

View file

@ -1,9 +1,12 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
__all__ = [
"SDXLUNet",
"DoubleTextEncoder",
"StableDiffusion_XL",
"SDXLIPAdapter",
]

View file

@ -0,0 +1,6 @@
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
pass

View file

@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion import (
SD1UNet,
SD1ControlnetAdapter,
SD1IPAdapter,
SDXLIPAdapter,
)
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.schedulers import DDIM
@ -74,6 +75,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
@ -217,6 +223,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
@ -1050,6 +1065,64 @@ def test_diffusion_ip_adapter(
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
@torch.no_grad()
def test_diffusion_sdxl_ip_adapter(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_weights: Path,
image_encoder_weights: Path,
woman_image: Image.Image,
expected_image_sdxl_ip_adapter_woman: Image.Image,
test_device: torch.device,
):
sdxl = sdxl_ddim.to(dtype=torch.float16)
n_steps = 30
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
with torch.no_grad():
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
with torch.no_grad():
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
# See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
# internal activation values are too big"
sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
@torch.no_grad()
def test_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device

View file

@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
Special cases:
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that it looks reasonable).
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that they look reasonable).
## Other images

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB