diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index 98d35e4..8fcc3fe 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -1,10 +1,9 @@ import argparse from pathlib import Path -from typing import Any import torch -from refiners.fluxion.utils import load_tensors, save_to_safetensors +from refiners.fluxion.utils import save_to_safetensors from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet # Running: @@ -66,13 +65,17 @@ def main() -> None: if args.output_path is None: args.output_path = f"{Path(args.source_path).stem}.safetensors" - weights: dict[str, Any] = load_tensors(args.source_path, device="cpu") + # Do not use `load_tensors`: first-level values are not tensors. + weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu") # type: ignore assert isinstance(weights, dict) assert sorted(weights.keys()) == ["image_proj", "ip_adapter"] - fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus + image_proj_weights = weights["image_proj"] + ip_adapter_weights = weights["ip_adapter"] - match len(weights["ip_adapter"]): + fine_grained = "latents" in image_proj_weights # aka IP-Adapter plus + + match len(ip_adapter_weights): case 32: ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained) cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"] @@ -87,7 +90,6 @@ def main() -> None: state_dict: dict[str, torch.Tensor] = {} - image_proj_weights = weights["image_proj"] image_proj_state_dict: dict[str, torch.Tensor] if fine_grained: @@ -130,7 +132,6 @@ def main() -> None: for k, v in image_proj_state_dict.items(): state_dict[f"image_proj.{k}"] = v - ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"] assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2 for i, _ in enumerate(ip_adapter.sub_adapters):