From 83c95fcf44df8fb829b7d4ac063330199898c251 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 26 Jan 2024 17:03:53 +0100 Subject: [PATCH] fix sorting method for LoRA keys - support _out_0 - sort _in before _out - avoid false positives by only considering suffixes --- .../foundationals/latent_diffusion/lora.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index eb2c491..a881284 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -44,7 +44,7 @@ class SDLoraManager: loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)} # if no key contains "unet" or "text", assume all keys are for the unet - if all(["unet" not in key and "text" not in key for key in loras.keys()]): + if all("unet" not in key and "text" not in key for key in loras.keys()): loras = {f"unet_{key}": value for key, value in loras.items()} self.add_loras_to_unet(loras) @@ -141,15 +141,12 @@ class SDLoraManager: @staticmethod def sort_keys(key: str, /) -> tuple[str, int]: - # out0 happens sometimes as an alias for out ; this dict might not be exhaustive - key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4} - - for i, s in enumerate(key.split("_")): - if s in key_char_order: - prefix = SDLoraManager.pad("_".join(key.split("_")[:i])) - return (prefix, key_char_order[s]) - - return (SDLoraManager.pad(key), 5) + # this dict might not be exhaustive + suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4} + patterns = ["_{}", "_{}_lora"] + key_char_order = {f.format(k): v for k, v in suffix_scores.items() for f in patterns} + (sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), ("", 5)) + return (SDLoraManager.pad(key.removesuffix(sfx)), score) @staticmethod def auto_attach(