diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index 1e50ad1..4e92f59 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -92,9 +92,13 @@ def main() -> None: v_ip = f"{cross_attn_index}.to_v_ip.weight" # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights + + names = [k for k, _ in cross_attn.named_parameters()] + assert len(names) == 2 + cross_attn_state_dict: dict[str, Any] = { - cross_attn.get_parameter_name("wk_prime"): ip_adapter_weights[k_ip], - cross_attn.get_parameter_name("wv_prime"): ip_adapter_weights[v_ip], + names[0]: ip_adapter_weights[k_ip], + names[1]: ip_adapter_weights[v_ip], } cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)