mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add support for SDXL IP-Adapter
This only supports the latest SDXL IP-Adapter release (2023.9.8) which builds upon the ViT-H/14 CLIP image encoder.
This commit is contained in:
parent
1b4dcebe06
commit
eea340c6c4
|
@ -4,60 +4,9 @@ import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter
|
from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter, SDXLUNet, SDXLIPAdapter
|
||||||
from refiners.fluxion.utils import save_to_safetensors
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source_path",
|
|
||||||
default="ip-adapter_sd15.bin",
|
|
||||||
help="Path to the source model. (default: 'ip-adapter_sd15.bin').",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--to",
|
|
||||||
type=str,
|
|
||||||
dest="output_path",
|
|
||||||
default="ip-adapter_sd15.safetensors",
|
|
||||||
help="Path to save the converted model. (default: 'ip-adapter_sd15.safetensors').",
|
|
||||||
)
|
|
||||||
parser.add_argument("--verbose", action="store_true", dest="verbose")
|
|
||||||
parser.add_argument("--half", action="store_true", dest="half")
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.output_path is None:
|
|
||||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
|
||||||
|
|
||||||
weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
|
|
||||||
assert isinstance(weights, dict)
|
|
||||||
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
|
||||||
|
|
||||||
unet = SD1UNet(in_channels=4)
|
|
||||||
|
|
||||||
ip_adapter = SD1IPAdapter(target=unet)
|
|
||||||
|
|
||||||
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
|
|
||||||
# [1]: https://github.com/tencent-ailab/IP-Adapter
|
|
||||||
|
|
||||||
state_dict: dict[str, torch.Tensor] = {}
|
|
||||||
|
|
||||||
image_proj_weights = weights["image_proj"]
|
|
||||||
image_proj_state_dict: dict[str, torch.Tensor] = {
|
|
||||||
"Linear.weight": image_proj_weights["proj.weight"],
|
|
||||||
"Linear.bias": image_proj_weights["proj.bias"],
|
|
||||||
"LayerNorm.weight": image_proj_weights["norm.weight"],
|
|
||||||
"LayerNorm.bias": image_proj_weights["norm.bias"],
|
|
||||||
}
|
|
||||||
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
|
|
||||||
|
|
||||||
for k, v in image_proj_state_dict.items():
|
|
||||||
state_dict[f"image_proj.{k}"] = v
|
|
||||||
|
|
||||||
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
|
|
||||||
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
|
|
||||||
|
|
||||||
# Running:
|
# Running:
|
||||||
#
|
#
|
||||||
# from diffusers import UNet2DConditionModel
|
# from diffusers import UNet2DConditionModel
|
||||||
|
@ -84,7 +33,72 @@ def main() -> None:
|
||||||
# DownBlocks -> [1, 3, 5, 7, 9, 11]
|
# DownBlocks -> [1, 3, 5, 7, 9, 11]
|
||||||
# MiddleBlock -> [31]
|
# MiddleBlock -> [31]
|
||||||
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
|
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
|
||||||
cross_attn_mapping: list[int] = [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29]
|
#
|
||||||
|
# Same for SDXL with more layers (70 cross-attentions vs. 16)
|
||||||
|
CROSS_ATTN_MAPPING: dict[str, list[int]] = {
|
||||||
|
"sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
|
||||||
|
"sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
dest="source_path",
|
||||||
|
help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--to",
|
||||||
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Path to save the converted model. If not specified, the output path will be the source path with the"
|
||||||
|
" extension changed to .safetensors."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--verbose", action="store_true", dest="verbose")
|
||||||
|
parser.add_argument("--half", action="store_true", dest="half")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||||
|
|
||||||
|
weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
|
||||||
|
assert isinstance(weights, dict)
|
||||||
|
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
||||||
|
|
||||||
|
match len(weights["ip_adapter"]):
|
||||||
|
case 32:
|
||||||
|
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4))
|
||||||
|
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
|
||||||
|
case 140:
|
||||||
|
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4))
|
||||||
|
cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
|
||||||
|
case _:
|
||||||
|
raise ValueError("Unexpected number of keys in input checkpoint")
|
||||||
|
|
||||||
|
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
|
||||||
|
# [1]: https://github.com/tencent-ailab/IP-Adapter
|
||||||
|
|
||||||
|
state_dict: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
image_proj_weights = weights["image_proj"]
|
||||||
|
image_proj_state_dict: dict[str, torch.Tensor] = {
|
||||||
|
"Linear.weight": image_proj_weights["proj.weight"],
|
||||||
|
"Linear.bias": image_proj_weights["proj.bias"],
|
||||||
|
"LayerNorm.weight": image_proj_weights["norm.weight"],
|
||||||
|
"LayerNorm.bias": image_proj_weights["norm.bias"],
|
||||||
|
}
|
||||||
|
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
|
||||||
|
|
||||||
|
for k, v in image_proj_state_dict.items():
|
||||||
|
state_dict[f"image_proj.{k}"] = v
|
||||||
|
|
||||||
|
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
|
||||||
|
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
|
||||||
|
|
||||||
for i, cross_attn in enumerate(ip_adapter.sub_adapters):
|
for i, cross_attn in enumerate(ip_adapter.sub_adapters):
|
||||||
cross_attn_index = cross_attn_mapping[i]
|
cross_attn_index = cross_attn_mapping[i]
|
||||||
|
@ -107,6 +121,8 @@ def main() -> None:
|
||||||
|
|
||||||
if args.half:
|
if args.half:
|
||||||
state_dict = {key: value.half() for key, value in state_dict.items()}
|
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||||
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -121,6 +121,7 @@ class CLIPImageEncoder(fl.Chain):
|
||||||
structural_attrs = [
|
structural_attrs = [
|
||||||
"image_size",
|
"image_size",
|
||||||
"embedding_dim",
|
"embedding_dim",
|
||||||
|
"output_dim",
|
||||||
"patch_size",
|
"patch_size",
|
||||||
"num_layers",
|
"num_layers",
|
||||||
"num_attention_heads",
|
"num_attention_heads",
|
||||||
|
@ -142,6 +143,7 @@ class CLIPImageEncoder(fl.Chain):
|
||||||
) -> None:
|
) -> None:
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
|
@ -15,6 +15,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
||||||
SDXLUNet,
|
SDXLUNet,
|
||||||
DoubleTextEncoder,
|
DoubleTextEncoder,
|
||||||
|
SDXLIPAdapter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,6 +27,7 @@ __all__ = [
|
||||||
"SD1IPAdapter",
|
"SD1IPAdapter",
|
||||||
"SDXLUNet",
|
"SDXLUNet",
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
|
"SDXLIPAdapter",
|
||||||
"DPMSolver",
|
"DPMSolver",
|
||||||
"Scheduler",
|
"Scheduler",
|
||||||
"CLIPTextEncoderL",
|
"CLIPTextEncoderL",
|
||||||
|
|
|
@ -1,21 +1,23 @@
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Generic, TypeVar, Any, Callable
|
from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
|
||||||
|
|
||||||
from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
|
from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
from refiners.fluxion.adapters.lora import Lora
|
from refiners.fluxion.adapters.lora import Lora
|
||||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
|
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
|
||||||
from refiners.fluxion.layers.module import Module
|
|
||||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||||
from refiners.fluxion.utils import image_to_tensor
|
from refiners.fluxion.utils import image_to_tensor
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
T = TypeVar("T", bound=SD1UNet | SDXLUNet)
|
if TYPE_CHECKING:
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
|
||||||
TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
|
TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,7 +130,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
def scale_outputs(self, x: Tensor) -> Tensor:
|
def scale_outputs(self, x: Tensor) -> Tensor:
|
||||||
return x * self.scale
|
return x * self.scale
|
||||||
|
|
||||||
def _predicate(self, k: type[Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||||
def f(m: fl.Module, _: fl.Chain) -> bool:
|
def f(m: fl.Module, _: fl.Chain) -> bool:
|
||||||
if isinstance(m, Lora): # do not adapt LoRAs
|
if isinstance(m, Lora): # do not adapt LoRAs
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
@ -167,15 +169,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: T,
|
target: T,
|
||||||
clip_image_encoder: CLIPImageEncoder,
|
clip_image_encoder: CLIPImageEncoderH | None = None,
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
weights: dict[str, Tensor] | None = None,
|
weights: dict[str, Tensor] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(target)
|
super().__init__(target)
|
||||||
|
|
||||||
self.clip_image_encoder = clip_image_encoder
|
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
|
||||||
self.image_proj = ImageProjection(device=target.device, dtype=target.dtype)
|
|
||||||
|
self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
||||||
|
self.image_proj = ImageProjection(
|
||||||
|
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
|
||||||
|
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
|
||||||
|
device=target.device,
|
||||||
|
dtype=target.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
self.sub_adapters = [
|
self.sub_adapters = [
|
||||||
CrossAttentionAdapter(target=cross_attn, scale=scale)
|
CrossAttentionAdapter(target=cross_attn, scale=scale)
|
||||||
|
|
|
@ -1,22 +1,6 @@
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
|
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
|
||||||
|
|
||||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
|
||||||
|
|
||||||
|
|
||||||
class SD1IPAdapter(IPAdapter[SD1UNet]):
|
class SD1IPAdapter(IPAdapter[SD1UNet]):
|
||||||
def __init__(
|
pass
|
||||||
self,
|
|
||||||
target: SD1UNet,
|
|
||||||
clip_image_encoder: CLIPImageEncoderH | None = None,
|
|
||||||
scale: float = 1.0,
|
|
||||||
weights: dict[str, Tensor] | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
target=target,
|
|
||||||
clip_image_encoder=clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype),
|
|
||||||
scale=scale,
|
|
||||||
weights=weights,
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SDXLUNet",
|
"SDXLUNet",
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
"StableDiffusion_XL",
|
"StableDiffusion_XL",
|
||||||
|
"SDXLIPAdapter",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
||||||
|
pass
|
|
@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion import (
|
||||||
SD1UNet,
|
SD1UNet,
|
||||||
SD1ControlnetAdapter,
|
SD1ControlnetAdapter,
|
||||||
SD1IPAdapter,
|
SD1IPAdapter,
|
||||||
|
SDXLIPAdapter,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||||
|
@ -74,6 +75,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
|
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
|
||||||
|
@ -217,6 +223,15 @@ def ip_adapter_weights(test_weights_path: Path) -> Path:
|
||||||
return ip_adapter_weights
|
return ip_adapter_weights
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
|
||||||
|
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
|
||||||
|
if not ip_adapter_weights.is_file():
|
||||||
|
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return ip_adapter_weights
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def image_encoder_weights(test_weights_path: Path) -> Path:
|
def image_encoder_weights(test_weights_path: Path) -> Path:
|
||||||
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
||||||
|
@ -1050,6 +1065,64 @@ def test_diffusion_ip_adapter(
|
||||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_sdxl_ip_adapter(
|
||||||
|
sdxl_ddim: StableDiffusion_XL,
|
||||||
|
sdxl_ip_adapter_weights: Path,
|
||||||
|
image_encoder_weights: Path,
|
||||||
|
woman_image: Image.Image,
|
||||||
|
expected_image_sdxl_ip_adapter_woman: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sdxl = sdxl_ddim.to(dtype=torch.float16)
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
prompt = "best quality, high quality"
|
||||||
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
|
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
|
||||||
|
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||||
|
ip_adapter.inject()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
|
text=prompt, negative_text=negative_prompt
|
||||||
|
)
|
||||||
|
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
|
||||||
|
|
||||||
|
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||||
|
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||||
|
|
||||||
|
clip_text_embedding = torch.cat(
|
||||||
|
(
|
||||||
|
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||||
|
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
time_ids = sdxl.default_time_ids
|
||||||
|
sdxl.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sdxl.steps:
|
||||||
|
x = sdxl(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
pooled_text_embedding=pooled_text_embedding,
|
||||||
|
time_ids=time_ids,
|
||||||
|
condition_scale=5,
|
||||||
|
)
|
||||||
|
# See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
|
||||||
|
# internal activation values are too big"
|
||||||
|
sdxl.lda.to(dtype=torch.float32)
|
||||||
|
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_sdxl_random_init(
|
def test_sdxl_random_init(
|
||||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
||||||
|
|
|
@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
|
||||||
Special cases:
|
Special cases:
|
||||||
|
|
||||||
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
|
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
|
||||||
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that it looks reasonable).
|
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that they look reasonable).
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 1.1 MiB |
Loading…
Reference in a new issue