fix sorting method for LoRA keys

- support _out_0
- sort _in before _out
- avoid false positives by only considering suffixes
This commit is contained in:
Pierre Chapuis 2024-01-26 17:03:53 +01:00
parent ce22c8f51b
commit 83c95fcf44

View file

@ -44,7 +44,7 @@ class SDLoraManager:
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
# if no key contains "unet" or "text", assume all keys are for the unet
if all(["unet" not in key and "text" not in key for key in loras.keys()]):
if all("unet" not in key and "text" not in key for key in loras.keys()):
loras = {f"unet_{key}": value for key, value in loras.items()}
self.add_loras_to_unet(loras)
@ -141,15 +141,12 @@ class SDLoraManager:
@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
for i, s in enumerate(key.split("_")):
if s in key_char_order:
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
return (prefix, key_char_order[s])
return (SDLoraManager.pad(key), 5)
# this dict might not be exhaustive
suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4}
patterns = ["_{}", "_{}_lora"]
key_char_order = {f.format(k): v for k, v in suffix_scores.items() for f in patterns}
(sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), ("", 5))
return (SDLoraManager.pad(key.removesuffix(sfx)), score)
@staticmethod
def auto_attach(