add support for SDXL IP-Adapter

This only supports the latest SDXL IP-Adapter release (2023.9.8) which
builds upon the ViT-H/14 CLIP image encoder.
This commit is contained in:
Cédric Deltheil 2023-09-12 17:28:13 +02:00 committed by Cédric Deltheil
parent 1b4dcebe06
commit eea340c6c4
10 changed files with 159 additions and 64 deletions

View file

@ -4,60 +4,9 @@ import argparse
import torch import torch
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter, SDXLUNet, SDXLIPAdapter
from refiners.fluxion.utils import save_to_safetensors from refiners.fluxion.utils import save_to_safetensors
def main() -> None:
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="ip-adapter_sd15.bin",
help="Path to the source model. (default: 'ip-adapter_sd15.bin').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default="ip-adapter_sd15.safetensors",
help="Path to save the converted model. (default: 'ip-adapter_sd15.safetensors').",
)
parser.add_argument("--verbose", action="store_true", dest="verbose")
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
unet = SD1UNet(in_channels=4)
ip_adapter = SD1IPAdapter(target=unet)
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
# [1]: https://github.com/tencent-ailab/IP-Adapter
state_dict: dict[str, torch.Tensor] = {}
image_proj_weights = weights["image_proj"]
image_proj_state_dict: dict[str, torch.Tensor] = {
"Linear.weight": image_proj_weights["proj.weight"],
"Linear.bias": image_proj_weights["proj.bias"],
"LayerNorm.weight": image_proj_weights["norm.weight"],
"LayerNorm.bias": image_proj_weights["norm.bias"],
}
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
for k, v in image_proj_state_dict.items():
state_dict[f"image_proj.{k}"] = v
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
# Running: # Running:
# #
# from diffusers import UNet2DConditionModel # from diffusers import UNet2DConditionModel
@ -84,7 +33,72 @@ def main() -> None:
# DownBlocks -> [1, 3, 5, 7, 9, 11] # DownBlocks -> [1, 3, 5, 7, 9, 11]
# MiddleBlock -> [31] # MiddleBlock -> [31]
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29] # UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
cross_attn_mapping: list[int] = [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29] #
# Same for SDXL with more layers (70 cross-attentions vs. 16)
CROSS_ATTN_MAPPING: dict[str, list[int]] = {
"sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
"sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
}
def main() -> None:
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--verbose", action="store_true", dest="verbose")
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
match len(weights["ip_adapter"]):
case 32:
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
case 140:
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
case _:
raise ValueError("Unexpected number of keys in input checkpoint")
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
# [1]: https://github.com/tencent-ailab/IP-Adapter
state_dict: dict[str, torch.Tensor] = {}
image_proj_weights = weights["image_proj"]
image_proj_state_dict: dict[str, torch.Tensor] = {
"Linear.weight": image_proj_weights["proj.weight"],
"Linear.bias": image_proj_weights["proj.bias"],
"LayerNorm.weight": image_proj_weights["norm.weight"],
"LayerNorm.bias": image_proj_weights["norm.bias"],
}
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
for k, v in image_proj_state_dict.items():
state_dict[f"image_proj.{k}"] = v
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
for i, cross_attn in enumerate(ip_adapter.sub_adapters): for i, cross_attn in enumerate(ip_adapter.sub_adapters):
cross_attn_index = cross_attn_mapping[i] cross_attn_index = cross_attn_mapping[i]
@ -107,6 +121,8 @@ def main() -> None:
if args.half: if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()} state_dict = {key: value.half() for key, value in state_dict.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=state_dict) save_to_safetensors(path=args.output_path, tensors=state_dict)

View file

@ -121,6 +121,7 @@ class CLIPImageEncoder(fl.Chain):
structural_attrs = [ structural_attrs = [
"image_size", "image_size",
"embedding_dim", "embedding_dim",
"output_dim",
"patch_size", "patch_size",
"num_layers", "num_layers",
"num_attention_heads", "num_attention_heads",
@ -142,6 +143,7 @@ class CLIPImageEncoder(fl.Chain):
) -> None: ) -> None:
self.image_size = image_size self.image_size = image_size
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.output_dim = output_dim
self.patch_size = patch_size self.patch_size = patch_size
self.num_layers = num_layers self.num_layers = num_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads

View file

@ -15,6 +15,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import ( from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
SDXLUNet, SDXLUNet,
DoubleTextEncoder, DoubleTextEncoder,
SDXLIPAdapter,
) )
@ -26,6 +27,7 @@ __all__ = [
"SD1IPAdapter", "SD1IPAdapter",
"SDXLUNet", "SDXLUNet",
"DoubleTextEncoder", "DoubleTextEncoder",
"SDXLIPAdapter",
"DPMSolver", "DPMSolver",
"Scheduler", "Scheduler",
"CLIPTextEncoderL", "CLIPTextEncoderL",

View file

@ -1,21 +1,23 @@
from enum import IntEnum from enum import IntEnum
from functools import partial from functools import partial
from typing import Generic, TypeVar, Any, Callable from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
from PIL import Image from PIL import Image
from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Lora from refiners.fluxion.adapters.lora import Lora
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.fluxion.layers.module import Module
from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor from refiners.fluxion.utils import image_to_tensor
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
T = TypeVar("T", bound=SD1UNet | SDXLUNet) if TYPE_CHECKING:
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673) TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
@ -128,7 +130,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def scale_outputs(self, x: Tensor) -> Tensor: def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale return x * self.scale
def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]: def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool: def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt LoRAs if isinstance(m, Lora): # do not adapt LoRAs
raise StopIteration raise StopIteration
@ -167,15 +169,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__( def __init__(
self, self,
target: T, target: T,
clip_image_encoder: CLIPImageEncoder, clip_image_encoder: CLIPImageEncoderH | None = None,
scale: float = 1.0, scale: float = 1.0,
weights: dict[str, Tensor] | None = None, weights: dict[str, Tensor] | None = None,
) -> None: ) -> None:
with self.setup_adapter(target): with self.setup_adapter(target):
super().__init__(target) super().__init__(target)
self.clip_image_encoder = clip_image_encoder cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
self.image_proj = ImageProjection(device=target.device, dtype=target.dtype)
self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
self.image_proj = ImageProjection(
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
self.sub_adapters = [ self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale) CrossAttentionAdapter(target=cross_attn, scale=scale)

View file

@ -1,22 +1,6 @@
from torch import Tensor
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
class SD1IPAdapter(IPAdapter[SD1UNet]): class SD1IPAdapter(IPAdapter[SD1UNet]):
def __init__( pass
self,
target: SD1UNet,
clip_image_encoder: CLIPImageEncoderH | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
super().__init__(
target=target,
clip_image_encoder=clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype),
scale=scale,
weights=weights,
)

View file

@ -1,9 +1,12 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
__all__ = [ __all__ = [
"SDXLUNet", "SDXLUNet",
"DoubleTextEncoder", "DoubleTextEncoder",
"StableDiffusion_XL", "StableDiffusion_XL",
"SDXLIPAdapter",
] ]

View file

@ -0,0 +1,6 @@
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
pass

View file

@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion import (
SD1UNet, SD1UNet,
SD1ControlnetAdapter, SD1ControlnetAdapter,
SD1IPAdapter, SD1IPAdapter,
SDXLIPAdapter,
) )
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.schedulers import DDIM from refiners.foundationals.latent_diffusion.schedulers import DDIM
@ -74,6 +75,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB") return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image: def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB") return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
@ -217,6 +223,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
return ip_adapter_weights return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path: def image_encoder_weights(test_weights_path: Path) -> Path:
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors" image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
@ -1050,6 +1065,64 @@ def test_diffusion_ip_adapter(
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman) ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
@torch.no_grad()
def test_diffusion_sdxl_ip_adapter(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_weights: Path,
image_encoder_weights: Path,
woman_image: Image.Image,
expected_image_sdxl_ip_adapter_woman: Image.Image,
test_device: torch.device,
):
sdxl = sdxl_ddim.to(dtype=torch.float16)
n_steps = 30
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
with torch.no_grad():
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
with torch.no_grad():
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
# See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
# internal activation values are too big"
sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
@torch.no_grad() @torch.no_grad()
def test_sdxl_random_init( def test_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device

View file

@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
Special cases: Special cases:
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui). - `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that it looks reasonable). - `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that they look reasonable).
## Other images ## Other images

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB