Update src/refiners/foundationals/latent_diffusion/lora.py

Co-authored-by: Laureηt <laurent@lagon.tech>
This commit is contained in:
Pierre Chapuis 2024-02-07 16:00:56 +01:00
parent b1c200c63a
commit 4d85918336

View file

@ -235,6 +235,21 @@ class SDLoraManager:
@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
"""Compute the score of a key, relatively to its suffix.
When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level".
The idea is that sometimes closely related keys in the state dict are not in the
same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This
attempts to fix that issue, not cases where distant layers are called in a different
order.
Args:
key: The key to sort.
Returns:
The padded prefix of the key.
A score depending on the key's suffix.
"""
# this dict might not be exhaustive
suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4}
patterns = ["_{}", "_{}_lora"]