diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index bc16523..a364740 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -235,6 +235,21 @@ class SDLoraManager: @staticmethod def sort_keys(key: str, /) -> tuple[str, int]: + """Compute the score of a key, relatively to its suffix. + + When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level". + The idea is that sometimes closely related keys in the state dict are not in the + same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This + attempts to fix that issue, not cases where distant layers are called in a different + order. + + Args: + key: The key to sort. + + Returns: + The padded prefix of the key. + A score depending on the key's suffix. + """ # this dict might not be exhaustive suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4} patterns = ["_{}", "_{}_lora"]