mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
do not use get_parameter_name in conversion script
This commit is contained in:
parent
3c056e2231
commit
43075f60b0
|
@ -92,9 +92,13 @@ def main() -> None:
|
||||||
v_ip = f"{cross_attn_index}.to_v_ip.weight"
|
v_ip = f"{cross_attn_index}.to_v_ip.weight"
|
||||||
|
|
||||||
# Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
|
# Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
|
||||||
|
|
||||||
|
names = [k for k, _ in cross_attn.named_parameters()]
|
||||||
|
assert len(names) == 2
|
||||||
|
|
||||||
cross_attn_state_dict: dict[str, Any] = {
|
cross_attn_state_dict: dict[str, Any] = {
|
||||||
cross_attn.get_parameter_name("wk_prime"): ip_adapter_weights[k_ip],
|
names[0]: ip_adapter_weights[k_ip],
|
||||||
cross_attn.get_parameter_name("wv_prime"): ip_adapter_weights[v_ip],
|
names[1]: ip_adapter_weights[v_ip],
|
||||||
}
|
}
|
||||||
cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
|
cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue