mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
fix IP-Adapter weights conversion
This commit is contained in:
parent
5ab5d7fd1c
commit
8139b2dd91
|
@ -1,10 +1,9 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
|
||||
|
||||
# Running:
|
||||
|
@ -66,13 +65,17 @@ def main() -> None:
|
|||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||
|
||||
weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
|
||||
# Do not use `load_tensors`: first-level values are not tensors.
|
||||
weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu") # type: ignore
|
||||
assert isinstance(weights, dict)
|
||||
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
||||
|
||||
fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus
|
||||
image_proj_weights = weights["image_proj"]
|
||||
ip_adapter_weights = weights["ip_adapter"]
|
||||
|
||||
match len(weights["ip_adapter"]):
|
||||
fine_grained = "latents" in image_proj_weights # aka IP-Adapter plus
|
||||
|
||||
match len(ip_adapter_weights):
|
||||
case 32:
|
||||
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
|
||||
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
|
||||
|
@ -87,7 +90,6 @@ def main() -> None:
|
|||
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
|
||||
image_proj_weights = weights["image_proj"]
|
||||
image_proj_state_dict: dict[str, torch.Tensor]
|
||||
|
||||
if fine_grained:
|
||||
|
@ -130,7 +132,6 @@ def main() -> None:
|
|||
for k, v in image_proj_state_dict.items():
|
||||
state_dict[f"image_proj.{k}"] = v
|
||||
|
||||
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
|
||||
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
|
||||
|
||||
for i, _ in enumerate(ip_adapter.sub_adapters):
|
||||
|
|
Loading…
Reference in a new issue