diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index ca23adb..aaccd4d 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -223,6 +223,8 @@ class SDLoraManager: @staticmethod def pad(input: str, /, padding_length: int = 2) -> str: + # make all numbers the same length so they sort correctly, + # e.g. foo.3.bar -> foo.03.bar new_split: list[str] = [] for s in input.split("_"): if s.isdigit(): @@ -236,8 +238,14 @@ class SDLoraManager: # this dict might not be exhaustive suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4} patterns = ["_{}", "_{}_lora"] + + # apply patterns to the keys of suffix_scores key_char_order = {f.format(k): v for k, v in suffix_scores.items() for f in patterns} + + # get the suffix and score for `key` (default: no suffix, highest score = 5) (sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), ("", 5)) + + # return tuple of: (padded key prefix, score) return (SDLoraManager.pad(key.removesuffix(sfx)), score) @staticmethod