mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
fix IP-Adapter weights conversion
This commit is contained in:
parent
5ab5d7fd1c
commit
8139b2dd91
|
@ -1,10 +1,9 @@
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
|
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
|
||||||
|
|
||||||
# Running:
|
# Running:
|
||||||
|
@ -66,13 +65,17 @@ def main() -> None:
|
||||||
if args.output_path is None:
|
if args.output_path is None:
|
||||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||||
|
|
||||||
weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
|
# Do not use `load_tensors`: the first-level values are dicts of tensors, not tensors.
|
||||||
|
weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu") # type: ignore
|
||||||
assert isinstance(weights, dict)
|
assert isinstance(weights, dict)
|
||||||
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
||||||
|
|
||||||
fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus
|
image_proj_weights = weights["image_proj"]
|
||||||
|
ip_adapter_weights = weights["ip_adapter"]
|
||||||
|
|
||||||
match len(weights["ip_adapter"]):
|
fine_grained = "latents" in image_proj_weights # aka IP-Adapter plus
|
||||||
|
|
||||||
|
match len(ip_adapter_weights):
|
||||||
case 32:
|
case 32:
|
||||||
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
|
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
|
||||||
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
|
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
|
||||||
|
@ -87,7 +90,6 @@ def main() -> None:
|
||||||
|
|
||||||
state_dict: dict[str, torch.Tensor] = {}
|
state_dict: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
image_proj_weights = weights["image_proj"]
|
|
||||||
image_proj_state_dict: dict[str, torch.Tensor]
|
image_proj_state_dict: dict[str, torch.Tensor]
|
||||||
|
|
||||||
if fine_grained:
|
if fine_grained:
|
||||||
|
@ -130,7 +132,6 @@ def main() -> None:
|
||||||
for k, v in image_proj_state_dict.items():
|
for k, v in image_proj_state_dict.items():
|
||||||
state_dict[f"image_proj.{k}"] = v
|
state_dict[f"image_proj.{k}"] = v
|
||||||
|
|
||||||
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
|
|
||||||
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
|
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
|
||||||
|
|
||||||
for i, _ in enumerate(ip_adapter.sub_adapters):
|
for i, _ in enumerate(ip_adapter.sub_adapters):
|
||||||
|
|
Loading…
Reference in a new issue