mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add support for SDXL IP-Adapter
This only supports the latest SDXL IP-Adapter release (2023.9.8) which builds upon the ViT-H/14 CLIP image encoder.
This commit is contained in:
parent
1b4dcebe06
commit
eea340c6c4
|
@ -4,25 +4,61 @@ import argparse
|
|||
|
||||
import torch
|
||||
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter, SDXLUNet, SDXLIPAdapter
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
|
||||
# Running:
|
||||
#
|
||||
# from diffusers import UNet2DConditionModel
|
||||
# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
||||
# for k in unet.attn_processors.keys():
|
||||
# print(k)
|
||||
#
|
||||
# Gives:
|
||||
#
|
||||
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
|
||||
# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
|
||||
# ...
|
||||
# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
|
||||
# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
|
||||
# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
|
||||
# ...
|
||||
# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
|
||||
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
|
||||
# mid_block.attentions.0.transformer_blocks.0.attn2.processor
|
||||
#
|
||||
# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
|
||||
# indices:
|
||||
#
|
||||
# DownBlocks -> [1, 3, 5, 7, 9, 11]
|
||||
# MiddleBlock -> [31]
|
||||
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
|
||||
#
|
||||
# Same for SDXL with more layers (70 cross-attentions vs. 16)
|
||||
CROSS_ATTN_MAPPING: dict[str, list[int]] = {
|
||||
"sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
|
||||
"sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
required=True,
|
||||
dest="source_path",
|
||||
default="ip-adapter_sd15.bin",
|
||||
help="Path to the source model. (default: 'ip-adapter_sd15.bin').",
|
||||
help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default="ip-adapter_sd15.safetensors",
|
||||
help="Path to save the converted model. (default: 'ip-adapter_sd15.safetensors').",
|
||||
default=None,
|
||||
help=(
|
||||
"Path to save the converted model. If not specified, the output path will be the source path with the"
|
||||
" extension changed to .safetensors."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", dest="verbose")
|
||||
parser.add_argument("--half", action="store_true", dest="half")
|
||||
|
@ -34,9 +70,15 @@ def main() -> None:
|
|||
assert isinstance(weights, dict)
|
||||
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
||||
|
||||
unet = SD1UNet(in_channels=4)
|
||||
|
||||
ip_adapter = SD1IPAdapter(target=unet)
|
||||
match len(weights["ip_adapter"]):
|
||||
case 32:
|
||||
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
|
||||
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
|
||||
case 140:
|
||||
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
|
||||
cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
|
||||
case _:
|
||||
raise ValueError("Unexpected number of keys in input checkpoint")
|
||||
|
||||
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
|
||||
# [1]: https://github.com/tencent-ailab/IP-Adapter
|
||||
|
@ -58,34 +100,6 @@ def main() -> None:
|
|||
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
|
||||
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
|
||||
|
||||
# Running:
|
||||
#
|
||||
# from diffusers import UNet2DConditionModel
|
||||
# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
||||
# for k in unet.attn_processors.keys():
|
||||
# print(k)
|
||||
#
|
||||
# Gives:
|
||||
#
|
||||
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
|
||||
# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
|
||||
# ...
|
||||
# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
|
||||
# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
|
||||
# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
|
||||
# ...
|
||||
# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
|
||||
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
|
||||
# mid_block.attentions.0.transformer_blocks.0.attn2.processor
|
||||
#
|
||||
# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
|
||||
# indices:
|
||||
#
|
||||
# DownBlocks -> [1, 3, 5, 7, 9, 11]
|
||||
# MiddleBlock -> [31]
|
||||
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
|
||||
cross_attn_mapping: list[int] = [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29]
|
||||
|
||||
for i, cross_attn in enumerate(ip_adapter.sub_adapters):
|
||||
cross_attn_index = cross_attn_mapping[i]
|
||||
k_ip = f"{cross_attn_index}.to_k_ip.weight"
|
||||
|
@ -107,6 +121,8 @@ def main() -> None:
|
|||
|
||||
if args.half:
|
||||
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||
|
||||
|
||||
|
|
|
@ -121,6 +121,7 @@ class CLIPImageEncoder(fl.Chain):
|
|||
structural_attrs = [
|
||||
"image_size",
|
||||
"embedding_dim",
|
||||
"output_dim",
|
||||
"patch_size",
|
||||
"num_layers",
|
||||
"num_attention_heads",
|
||||
|
@ -142,6 +143,7 @@ class CLIPImageEncoder(fl.Chain):
|
|||
) -> None:
|
||||
self.image_size = image_size
|
||||
self.embedding_dim = embedding_dim
|
||||
self.output_dim = output_dim
|
||||
self.patch_size = patch_size
|
||||
self.num_layers = num_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
|
|
@ -15,6 +15,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
|||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
||||
SDXLUNet,
|
||||
DoubleTextEncoder,
|
||||
SDXLIPAdapter,
|
||||
)
|
||||
|
||||
|
||||
|
@ -26,6 +27,7 @@ __all__ = [
|
|||
"SD1IPAdapter",
|
||||
"SDXLUNet",
|
||||
"DoubleTextEncoder",
|
||||
"SDXLIPAdapter",
|
||||
"DPMSolver",
|
||||
"Scheduler",
|
||||
"CLIPTextEncoderL",
|
||||
|
|
|
@ -1,21 +1,23 @@
|
|||
from enum import IntEnum
|
||||
from functools import partial
|
||||
from typing import Generic, TypeVar, Any, Callable
|
||||
from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
|
||||
|
||||
from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import Lora
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
from refiners.fluxion.layers.module import Module
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||
from refiners.fluxion.utils import image_to_tensor
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
T = TypeVar("T", bound=SD1UNet | SDXLUNet)
|
||||
if TYPE_CHECKING:
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
|
||||
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
|
||||
TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
|
||||
|
||||
|
||||
|
@ -128,7 +130,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
|||
def scale_outputs(self, x: Tensor) -> Tensor:
|
||||
return x * self.scale
|
||||
|
||||
def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||
def f(m: fl.Module, _: fl.Chain) -> bool:
|
||||
if isinstance(m, Lora): # do not adapt LoRAs
|
||||
raise StopIteration
|
||||
|
@ -167,15 +169,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
def __init__(
|
||||
self,
|
||||
target: T,
|
||||
clip_image_encoder: CLIPImageEncoder,
|
||||
clip_image_encoder: CLIPImageEncoderH | None = None,
|
||||
scale: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
self.clip_image_encoder = clip_image_encoder
|
||||
self.image_proj = ImageProjection(device=target.device, dtype=target.dtype)
|
||||
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
|
||||
|
||||
self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
||||
self.image_proj = ImageProjection(
|
||||
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
|
||||
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
)
|
||||
|
||||
self.sub_adapters = [
|
||||
CrossAttentionAdapter(target=cross_attn, scale=scale)
|
||||
|
|
|
@ -1,22 +1,6 @@
|
|||
from torch import Tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
|
||||
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||
|
||||
|
||||
class SD1IPAdapter(IPAdapter[SD1UNet]):
|
||||
def __init__(
|
||||
self,
|
||||
target: SD1UNet,
|
||||
clip_image_encoder: CLIPImageEncoderH | None = None,
|
||||
scale: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
target=target,
|
||||
clip_image_encoder=clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype),
|
||||
scale=scale,
|
||||
weights=weights,
|
||||
)
|
||||
pass
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SDXLUNet",
|
||||
"DoubleTextEncoder",
|
||||
"StableDiffusion_XL",
|
||||
"SDXLIPAdapter",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||
|
||||
|
||||
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
||||
pass
|
|
@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion import (
|
|||
SD1UNet,
|
||||
SD1ControlnetAdapter,
|
||||
SD1IPAdapter,
|
||||
SDXLIPAdapter,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||
|
@ -74,6 +75,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
|
||||
|
@ -217,6 +223,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
|
|||
return ip_adapter_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
|
||||
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
|
||||
if not ip_adapter_weights.is_file():
|
||||
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return ip_adapter_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def image_encoder_weights(test_weights_path: Path) -> Path:
|
||||
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
||||
|
@ -1050,6 +1065,64 @@ def test_diffusion_ip_adapter(
|
|||
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_sdxl_ip_adapter(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
sdxl_ip_adapter_weights: Path,
|
||||
image_encoder_weights: Path,
|
||||
woman_image: Image.Image,
|
||||
expected_image_sdxl_ip_adapter_woman: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sdxl = sdxl_ddim.to(dtype=torch.float16)
|
||||
n_steps = 30
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
with torch.no_grad():
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt, negative_text=negative_prompt
|
||||
)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
|
||||
|
||||
with torch.no_grad():
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
condition_scale=5,
|
||||
)
|
||||
# See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
|
||||
# internal activation values are too big"
|
||||
sdxl.lda.to(dtype=torch.float32)
|
||||
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_sdxl_random_init(
|
||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
||||
|
|
|
@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
|
|||
Special cases:
|
||||
|
||||
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
|
||||
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that it looks reasonable).
|
||||
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that they look reasonable).
|
||||
|
||||
## Other images
|
||||
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 1.1 MiB |
Loading…
Reference in a new issue