do not use get_parameter_name in conversion script

This commit is contained in:
Pierre Chapuis 2023-09-08 18:26:24 +02:00
parent 3c056e2231
commit 43075f60b0

View file

@ -92,9 +92,13 @@ def main() -> None:
v_ip = f"{cross_attn_index}.to_v_ip.weight"
# Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
names = [k for k, _ in cross_attn.named_parameters()]
assert len(names) == 2
cross_attn_state_dict: dict[str, Any] = {
cross_attn.get_parameter_name("wk_prime"): ip_adapter_weights[k_ip],
cross_attn.get_parameter_name("wv_prime"): ip_adapter_weights[v_ip],
names[0]: ip_adapter_weights[k_ip],
names[1]: ip_adapter_weights[v_ip],
}
cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)