fix IP-Adapter weights conversion

This commit is contained in:
Pierre Chapuis 2024-02-21 14:28:29 +01:00
parent 5ab5d7fd1c
commit 8139b2dd91

View file

@ -1,10 +1,9 @@
import argparse
from pathlib import Path
from typing import Any
import torch
from refiners.fluxion.utils import load_tensors, save_to_safetensors
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
# Running:
@ -66,13 +65,17 @@ def main() -> None:
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
# Do not use `load_tensors`: first-level values are not tensors.
weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu") # type: ignore
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus
image_proj_weights = weights["image_proj"]
ip_adapter_weights = weights["ip_adapter"]
match len(weights["ip_adapter"]):
fine_grained = "latents" in image_proj_weights # aka IP-Adapter plus
match len(ip_adapter_weights):
case 32:
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
@ -87,7 +90,6 @@ def main() -> None:
state_dict: dict[str, torch.Tensor] = {}
image_proj_weights = weights["image_proj"]
image_proj_state_dict: dict[str, torch.Tensor]
if fine_grained:
@ -130,7 +132,6 @@ def main() -> None:
for k, v in image_proj_state_dict.items():
state_dict[f"image_proj.{k}"] = v
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
for i, _ in enumerate(ip_adapter.sub_adapters):