From 4d8591833659ad67c5f893eda18e68d946e3127d Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 7 Feb 2024 16:00:56 +0100 Subject: [PATCH] Update src/refiners/foundationals/latent_diffusion/lora.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Laureηt --- .../foundationals/latent_diffusion/lora.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index bc16523..a364740 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -235,6 +235,21 @@ class SDLoraManager: @staticmethod def sort_keys(key: str, /) -> tuple[str, int]: + """Compute the score of a key, relatively to its suffix. + + When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level". + The idea is that sometimes closely related keys in the state dict are not in the + same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This + attempts to fix that issue, not cases where distant layers are called in a different + order. + + Args: + key: The key to sort. + + Returns: + The padded prefix of the key. + A score depending on the key's suffix. + """ # this dict might not be exhaustive suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4} patterns = ["_{}", "_{}_lora"]