From 43075f60b053d2544cde833ae250931fdea6d3c1 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 8 Sep 2023 18:26:24 +0200 Subject: [PATCH] do not use get_parameter_name in conversion script --- scripts/conversion/convert_diffusers_ip_adapter.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index 1e50ad1..4e92f59 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -92,9 +92,13 @@ def main() -> None: v_ip = f"{cross_attn_index}.to_v_ip.weight" # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights + + names = [k for k, _ in cross_attn.named_parameters()] + assert len(names) == 2 + cross_attn_state_dict: dict[str, Any] = { - cross_attn.get_parameter_name("wk_prime"): ip_adapter_weights[k_ip], - cross_attn.get_parameter_name("wv_prime"): ip_adapter_weights[v_ip], + names[0]: ip_adapter_weights[k_ip], + names[1]: ip_adapter_weights[v_ip], } cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)