Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-25 07:38:45 +00:00)

add IP-Adapter support for SD 1.5
Official repo: https://github.com/tencent-ailab/IP-Adapter
parent d4dd45fd4d
commit c55917e293

New file (111 lines): scripts/conversion/convert_diffusers_ip_adapter.py
```python
from pathlib import Path
from typing import Any
import argparse

import torch

from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1IPAdapter
from refiners.fluxion.utils import save_to_safetensors


def main() -> None:
    parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="ip-adapter_sd15.bin",
        help="Path to the source model. (default: 'ip-adapter_sd15.bin').",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default="ip-adapter_sd15.safetensors",
        help="Path to save the converted model. (default: 'ip-adapter_sd15.safetensors').",
    )
    parser.add_argument("--verbose", action="store_true", dest="verbose")
    parser.add_argument("--half", action="store_true", dest="half")
    args = parser.parse_args()
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}.safetensors"

    weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu")  # type: ignore
    assert isinstance(weights, dict)
    assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]

    unet = SD1UNet(in_channels=4)

    ip_adapter = SD1IPAdapter(target=unet)

    # Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
    # [1]: https://github.com/tencent-ailab/IP-Adapter

    state_dict: dict[str, torch.Tensor] = {}

    image_proj_weights = weights["image_proj"]
    image_proj_state_dict: dict[str, torch.Tensor] = {
        "Linear.weight": image_proj_weights["proj.weight"],
        "Linear.bias": image_proj_weights["proj.bias"],
        "LayerNorm.weight": image_proj_weights["norm.weight"],
        "LayerNorm.bias": image_proj_weights["norm.bias"],
    }
    ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)

    for k, v in image_proj_state_dict.items():
        state_dict[f"image_proj.{k}"] = v

    ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
    assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2

    # Running:
    #
    #     from diffusers import UNet2DConditionModel
    #     unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    #     for k in unet.attn_processors.keys():
    #         print(k)
    #
    # Gives:
    #
    #     down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
    #     down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
    #     ...
    #     down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
    #     up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
    #     up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
    #     ...
    #     up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
    #     mid_block.attentions.0.transformer_blocks.0.attn1.processor
    #     mid_block.attentions.0.transformer_blocks.0.attn2.processor
    #
    # With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of
    # increasing indices:
    #
    # DownBlocks  -> [1, 3, 5, 7, 9, 11]
    # MiddleBlock -> [31]
    # UpBlocks    -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
    cross_attn_mapping: list[int] = [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29]

    for i, cross_attn in enumerate(ip_adapter.sub_adapters):
        cross_attn_index = cross_attn_mapping[i]
        k_ip = f"{cross_attn_index}.to_k_ip.weight"
        v_ip = f"{cross_attn_index}.to_v_ip.weight"

        # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
        cross_attn_state_dict: dict[str, Any] = {
            cross_attn.get_parameter_name("wk_prime"): ip_adapter_weights[k_ip],
            cross_attn.get_parameter_name("wv_prime"): ip_adapter_weights[v_ip],
        }
        cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)

        for k, v in cross_attn_state_dict.items():
            state_dict[f"ip_adapter.{i:03d}.{k}"] = v

    if args.half:
        state_dict = {key: value.half() for key, value in state_dict.items()}
    save_to_safetensors(path=args.output_path, tensors=state_dict)


if __name__ == "__main__":
    main()
```
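For reference, a hedged sketch of how this script might be run and how its output can be checked. The flags and default filenames come from the argparse setup above; the exact checkpoint location is up to the user and is an assumption here.

```python
# Assumed invocation (paths are the script defaults, not a guaranteed download location):
#
#   python scripts/conversion/convert_diffusers_ip_adapter.py \
#       --from ip-adapter_sd15.bin --to ip-adapter_sd15.safetensors --half
#
# Sanity-check the converted weights with refiners' own loader:
from refiners.fluxion.utils import load_from_safetensors

tensors = load_from_safetensors("ip-adapter_sd15.safetensors")
assert any(k.startswith("image_proj.") for k in tensors)  # image projection weights
assert any(k.startswith("ip_adapter.") for k in tensors)  # per-cross-attention Wk' / Wv' weights
```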
```diff
@@ -10,6 +10,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     StableDiffusion_1_Inpainting,
     SD1UNet,
     SD1ControlnetAdapter,
+    SD1IPAdapter,
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
     SDXLUNet,
@@ -22,6 +23,7 @@ __all__ = [
     "StableDiffusion_1_Inpainting",
     "SD1UNet",
     "SD1ControlnetAdapter",
+    "SD1IPAdapter",
     "SDXLUNet",
     "DoubleTextEncoder",
     "DPMSolver",
```
```diff
@@ -4,10 +4,12 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     StableDiffusion_1_Inpainting,
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter

 __all__ = [
     "StableDiffusion_1",
     "StableDiffusion_1_Inpainting",
     "SD1UNet",
     "SD1ControlnetAdapter",
+    "SD1IPAdapter",
 ]
```
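With these re-exports in place, `SD1IPAdapter` is reachable from both package levels; a trivial sketch:

```python
# Either import path refers to the same class after this commit.
from refiners.foundationals.latent_diffusion import SD1IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1IPAdapter  # noqa: F811
```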
New file (348 lines), the `image_prompt` module imported above, providing `SD1IPAdapter`:

```python
from enum import IntEnum
from pathlib import Path
from functools import partial

from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
from PIL import Image

from refiners.fluxion.adapters.adapter import Adapter
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors
import refiners.fluxion.layers as fl


class ImageProjection(fl.Chain):
    structural_attrs = ["clip_image_embedding_dim", "clip_text_embedding_dim", "sequence_length"]

    def __init__(
        self,
        clip_image_embedding_dim: int = 1024,
        clip_text_embedding_dim: int = 768,
        sequence_length: int = 4,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.clip_image_embedding_dim = clip_image_embedding_dim
        self.clip_text_embedding_dim = clip_text_embedding_dim
        self.sequence_length = sequence_length
        super().__init__(
            fl.Linear(
                in_features=clip_image_embedding_dim,
                out_features=clip_text_embedding_dim * sequence_length,
                device=device,
                dtype=dtype,
            ),
            fl.Reshape(sequence_length, clip_text_embedding_dim),
            fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype),
        )


class _CrossAttnIndex(IntEnum):
    TXT_CROSS_ATTN = 0  # text cross-attention
    IMG_CROSS_ATTN = 1  # image cross-attention


# Fluxion's Attention layer drop-in replacement implementing Decoupled Cross-Attention
class IPAttention(fl.Chain):
    structural_attrs = [
        "embedding_dim",
        "text_sequence_length",
        "image_sequence_length",
        "scale",
        "num_heads",
        "heads_dim",
        "key_embedding_dim",
        "value_embedding_dim",
        "inner_dim",
        "use_bias",
        "is_causal",
    ]

    def __init__(
        self,
        embedding_dim: int,
        text_sequence_length: int = 77,
        image_sequence_length: int = 4,
        scale: float = 1.0,
        num_heads: int = 1,
        key_embedding_dim: int | None = None,
        value_embedding_dim: int | None = None,
        inner_dim: int | None = None,
        use_bias: bool = True,
        is_causal: bool | None = None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        assert (
            embedding_dim % num_heads == 0
        ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
        self.embedding_dim = embedding_dim
        self.text_sequence_length = text_sequence_length
        self.image_sequence_length = image_sequence_length
        self.scale = scale
        self.num_heads = num_heads
        self.heads_dim = embedding_dim // num_heads
        self.key_embedding_dim = key_embedding_dim or embedding_dim
        self.value_embedding_dim = value_embedding_dim or embedding_dim
        self.inner_dim = inner_dim or embedding_dim
        self.use_bias = use_bias
        self.is_causal = is_causal
        super().__init__(
            fl.Distribute(
                # Note: the same query is used for image cross-attention as for text cross-attention
                fl.Linear(
                    in_features=self.embedding_dim,
                    out_features=self.inner_dim,
                    bias=self.use_bias,
                    device=device,
                    dtype=dtype,
                ),  # Wq
                fl.Parallel(
                    fl.Chain(
                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
                        fl.Linear(
                            in_features=self.key_embedding_dim,
                            out_features=self.inner_dim,
                            bias=self.use_bias,
                            device=device,
                            dtype=dtype,
                        ),  # Wk
                    ),
                    fl.Chain(
                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
                        fl.Linear(
                            in_features=self.key_embedding_dim,
                            out_features=self.inner_dim,
                            bias=self.use_bias,
                            device=device,
                            dtype=dtype,
                        ),  # Wk'
                    ),
                ),
                fl.Parallel(
                    fl.Chain(
                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
                        fl.Linear(
                            in_features=self.key_embedding_dim,
                            out_features=self.inner_dim,
                            bias=self.use_bias,
                            device=device,
                            dtype=dtype,
                        ),  # Wv
                    ),
                    fl.Chain(
                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
                        fl.Linear(
                            in_features=self.key_embedding_dim,
                            out_features=self.inner_dim,
                            bias=self.use_bias,
                            device=device,
                            dtype=dtype,
                        ),  # Wv'
                    ),
                ),
            ),
            fl.Sum(
                fl.Chain(
                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
                ),
                fl.Chain(
                    fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
                    ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
                    fl.Lambda(func=self.scale_outputs),
                ),
            ),
            fl.Linear(
                in_features=self.inner_dim,
                out_features=self.embedding_dim,
                bias=True,
                device=device,
                dtype=dtype,
            ),
        )

    def select_qkv(
        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
    ) -> tuple[Tensor, Tensor, Tensor]:
        return (query, keys[index.value], values[index.value])

    def scale_outputs(self, x: Tensor) -> Tensor:
        return x * self.scale


class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
    structural_attrs = ["text_sequence_length", "image_sequence_length", "scale"]

    def __init__(
        self,
        target: fl.Attention,
        text_sequence_length: int = 77,
        image_sequence_length: int = 4,
        scale: float = 1.0,
    ) -> None:
        self.text_sequence_length = text_sequence_length
        self.image_sequence_length = image_sequence_length
        self.scale = scale
        with self.setup_adapter(target):
            super().__init__(
                IPAttention(
                    embedding_dim=target.embedding_dim,
                    text_sequence_length=text_sequence_length,
                    image_sequence_length=image_sequence_length,
                    scale=scale,
                    num_heads=target.num_heads,
                    key_embedding_dim=target.key_embedding_dim,
                    value_embedding_dim=target.value_embedding_dim,
                    inner_dim=target.inner_dim,
                    use_bias=target.use_bias,
                    is_causal=target.is_causal,
                    device=target.device,
                    dtype=target.dtype,
                )
            )

    def get_parameter_name(self, matrix: str, bias: bool = False) -> str:
        match matrix:
            case "wq":
                index = 0
            case "wk":
                index = 1
            case "wk_prime":
                index = 2
            case "wv":
                index = 3
            case "wv_prime":
                index = 4
            case "proj":
                index = 5
            case _:
                raise ValueError(f"Unexpected matrix name {matrix}")

        linear = list(self.IPAttention.layers(fl.Linear))[index]
        param = getattr(linear, "bias" if bias else "weight")
        name = next((n for n, p in self.named_parameters() if id(p) == id(param)), None)
        assert name is not None
        return name


class SD1IPAdapter(fl.Chain, Adapter[SD1UNet]):
    def __init__(
        self,
        target: SD1UNet,
        clip_image_encoder: CLIPImageEncoderH | None = None,
        scale: float = 1.0,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        with self.setup_adapter(target):
            super().__init__(target)

        self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
        self.image_proj = ImageProjection(device=target.device, dtype=target.dtype)

        self.sub_adapters = [
            CrossAttentionAdapter(target=cross_attn, scale=scale)
            for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
        ]

        if weights is not None:
            image_proj_state_dict: dict[str, Tensor] = {
                k.removeprefix("image_proj."): v for k, v in weights.items() if k.startswith("image_proj.")
            }
            self.image_proj.load_state_dict(image_proj_state_dict)

            for i, cross_attn in enumerate(self.sub_adapters):
                cross_attn_state_dict: dict[str, Tensor] = {}
                for k, v in weights.items():
                    prefix = f"ip_adapter.{i:03d}."
                    if not k.startswith(prefix):
                        continue
                    cross_attn_state_dict[k.removeprefix(prefix)] = v

                # Retrieve original (frozen) cross-attention weights
                # Note: this assumes the target UNet has already loaded weights
                cross_attn_linears = list(cross_attn.target.layers(fl.Linear))
                assert len(cross_attn_linears) == 4  # Wq, Wk, Wv and Proj

                cross_attn_state_dict[cross_attn.get_parameter_name("wq")] = cross_attn_linears[0].weight
                cross_attn_state_dict[cross_attn.get_parameter_name("wk")] = cross_attn_linears[1].weight
                cross_attn_state_dict[cross_attn.get_parameter_name("wv")] = cross_attn_linears[2].weight
                cross_attn_state_dict[cross_attn.get_parameter_name("proj")] = cross_attn_linears[3].weight
                cross_attn_state_dict[cross_attn.get_parameter_name("proj", bias=True)] = cross_attn_linears[3].bias

                cross_attn.load_state_dict(state_dict=cross_attn_state_dict)

    @classmethod
    def from_safetensors(
        cls,
        target: SD1UNet,
        checkpoint_path: Path | str,
        clip_image_encoder: CLIPImageEncoderH | None = None,
        scale: float = 1.0,
    ):
        weights = load_from_safetensors(checkpoint_path, device=target.device if target.device is not None else "cpu")

        return cls(
            target=target,
            clip_image_encoder=clip_image_encoder,
            scale=scale,
            weights=weights,
        )

    def inject(self: "SD1IPAdapter", parent: fl.Chain | None = None) -> "SD1IPAdapter":
        for adapter in self.sub_adapters:
            adapter.inject()
        return super().inject(parent)

    def eject(self) -> None:
        for adapter in self.sub_adapters:
            adapter.eject()
        super().eject()

    # These should be concatenated to the CLIP text embedding before setting the UNet context
    def compute_clip_image_embedding(self, image_prompt: Tensor | None) -> Tensor:
        clip_embedding = self.clip_image_encoder(image_prompt)
        conditional_embedding = self.image_proj(clip_embedding)
        negative_embedding = self.image_proj(zeros_like(clip_embedding))
        return cat((negative_embedding, conditional_embedding))

    def preprocess_image(
        self,
        image: Image.Image,
        size: tuple[int, int] = (224, 224),
        mean: list[float] | None = None,
        std: list[float] | None = None,
    ) -> Tensor:
        # Default mean and std are parameters from https://github.com/openai/CLIP
        return self._normalize(
            image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
            mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
            std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
        )

    # Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
    @staticmethod
    def _normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
        assert tensor.is_floating_point()
        assert tensor.ndim >= 3

        if not inplace:
            tensor = tensor.clone()

        dtype = tensor.dtype
        mean_tensor = as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
        std_tensor = as_tensor(std, dtype=tensor.dtype, device=tensor.device)

        if (std_tensor == 0).any():
            raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")

        if mean_tensor.ndim == 1:
            mean_tensor = mean_tensor.view(-1, 1, 1)
        if std_tensor.ndim == 1:
            std_tensor = std_tensor.view(-1, 1, 1)

        return tensor.sub_(mean_tensor).div_(std_tensor)
```
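As a reading aid (not part of the commit): the `fl.Sum` block in `IPAttention` above implements IP-Adapter's decoupled cross-attention, which can be summarized roughly as

```latex
\mathrm{IPAttention}(x) = \Big(\mathrm{Attn}(Q, K_{\mathrm{txt}}, V_{\mathrm{txt}})
    + \lambda \cdot \mathrm{Attn}(Q, K_{\mathrm{img}}, V_{\mathrm{img}})\Big)\, W_{\mathrm{proj}}

Q = x W_q, \quad
K_{\mathrm{txt}} = c_{\mathrm{txt}} W_k, \quad V_{\mathrm{txt}} = c_{\mathrm{txt}} W_v, \quad
K_{\mathrm{img}} = c_{\mathrm{img}} W'_k, \quad V_{\mathrm{img}} = c_{\mathrm{img}} W'_v
```

Here $c_{\mathrm{txt}}$ is the first 77 tokens of the concatenated context (selected by the first `fl.Slicing`), $c_{\mathrm{img}}$ the following 4 image tokens (second `fl.Slicing`), and $\lambda$ is `scale`, applied by `scale_outputs`. Only $W'_k$ and $W'_v$ are new weights, which is why the conversion script above maps just `to_k_ip` / `to_v_ip` and leaves Wq, Wk, Wv and Proj to the UNet's original weights.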
```diff
@@ -13,6 +13,7 @@ from refiners.foundationals.latent_diffusion import (
     StableDiffusion_1_Inpainting,
     SD1UNet,
     SD1ControlnetAdapter,
+    SD1IPAdapter,
 )
 from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
@@ -42,6 +43,11 @@ def kitchen_dog_mask(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")


+@pytest.fixture(scope="module")
+def woman_image(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "woman.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_std_random_init(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
@@ -62,6 +68,11 @@ def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")


+@pytest.fixture
+def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
+
+
 @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
 def controlnet_data(
     ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@@ -182,6 +193,33 @@ def unet_weights_inpainting(test_weights_path: Path) -> Path:
     return unet_weights_inpainting


+@pytest.fixture(scope="module")
+def lda_ft_mse_weights(test_weights_path: Path) -> Path:
+    lda_weights = test_weights_path / "lda_ft_mse.safetensors"
+    if not lda_weights.is_file():
+        warn(f"could not find weights at {lda_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return lda_weights
+
+
+@pytest.fixture(scope="module")
+def ip_adapter_weights(test_weights_path: Path) -> Path:
+    ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
+    if not ip_adapter_weights.is_file():
+        warn(f"could not find weights at {ip_adapter_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return ip_adapter_weights
+
+
+@pytest.fixture(scope="module")
+def image_encoder_weights(test_weights_path: Path) -> Path:
+    image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
+    if not image_encoder_weights.is_file():
+        warn(f"could not find weights at {image_encoder_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return image_encoder_weights
+
+
 @pytest.fixture
 def sd15_std(
     text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
@@ -270,6 +308,24 @@ def sd15_ddim(
     return sd15


+@pytest.fixture
+def sd15_ddim_lda_ft_mse(
+    text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
+) -> StableDiffusion_1:
+    if test_device.type == "cpu":
+        warn("not running on CPU, skipping")
+        pytest.skip()
+
+    ddim_scheduler = DDIM(num_inference_steps=20)
+    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+
+    sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
+    sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
+    sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
+
+    return sd15
+
+
 @torch.no_grad()
 def test_diffusion_std_random_init(
     sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
@@ -844,3 +900,60 @@ def test_diffusion_textual_inversion_random_init(
     predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
+
+
+@torch.no_grad()
+def test_diffusion_ip_adapter(
+    sd15_ddim_lda_ft_mse: StableDiffusion_1,
+    ip_adapter_weights: Path,
+    image_encoder_weights: Path,
+    woman_image: Image.Image,
+    expected_image_ip_adapter_woman: Image.Image,
+    test_device: torch.device,
+):
+    sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
+    n_steps = 50
+
+    # See tencent-ailab/IP-Adapter best practices section:
+    #
+    # If you only use the image prompt, you can set the scale=1.0 and text_prompt="" (or some generic text
+    # prompts, e.g. "best quality", you can also use any negative text prompt).
+    #
+    # The prompts below are the ones used by default by IPAdapter's generate method if none are specified
+    prompt = "best quality, high quality"
+    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+    ip_adapter = SD1IPAdapter.from_safetensors(target=sd15.unet, checkpoint_path=ip_adapter_weights)
+    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.inject()
+
+    with torch.no_grad():
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+        clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
+
+        negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
+        negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
+
+        clip_text_embedding = torch.cat(
+            (
+                torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
+                torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
+            )
+        )
+
+    sd15.set_num_inference_steps(n_steps)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
+
+    with torch.no_grad():
+        for step in sd15.steps:
+            x = sd15(
+                x,
+                step=step,
+                clip_text_embedding=clip_text_embedding,
+                condition_scale=7.5,
+            )
+        predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
```
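As a side note on the concatenation in the test above, here is a minimal shape sketch (assuming the SD 1.5 defaults used in this commit: 77 text tokens, 4 image tokens, embedding size 768):

```python
import torch

# Each classifier-free-guidance branch carries 77 text tokens followed by 4 image tokens,
# which is exactly what IPAttention slices apart with
# fl.Slicing(dim=1, start=0, length=77) and fl.Slicing(dim=1, start=77, length=4).
text_tokens = torch.randn(1, 77, 768)
image_tokens = torch.randn(1, 4, 768)
combined = torch.cat([text_tokens, image_tokens], dim=1)
assert combined.shape == (1, 81, 768)
```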
````diff
@@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
 Special cases:

 - `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
-- `expected_inpainting_refonly.png` has been generated with refiners itself (and inspected so that it looks reasonable).
+- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that they look reasonable).

 ## Other images

@@ -59,6 +59,8 @@ Special cases:
 - `inpainting-scene.png`: cropped-to-square-and-resized version of https://unsplash.com/photos/RCz6eSVPGYU by @jannerboy62
 - `inpainting-target.png`: computed with `convert <(convert -size 512x512 xc:white png:-) kitchen_dog.png <(convert inpainting-mask.png -negate png:-) -compose Over -composite inpainting-target.png`

+- `woman.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/8b96670cc5c8ef00278b42c0c7b62fe8a74510b9/assets/images/woman.png).
+
 ## VAE without randomness

 ```diff
````
New binary files:

- tests/e2e/test_diffusion_ref/expected_image_ip_adapter_woman.png (317 KiB)
- tests/e2e/test_diffusion_ref/woman.png (1.2 MiB)