make pad method private

This commit is contained in:
Pierre Chapuis 2024-02-07 16:08:52 +01:00
parent 4d85918336
commit 396d166564

View file

@ -222,9 +222,19 @@ class SDLoraManager:
return {name: self.get_scale(name) for name in self.names} return {name: self.get_scale(name) for name in self.names}
@staticmethod @staticmethod
def pad(input: str, /, padding_length: int = 2) -> str: def _pad(input: str, /, padding_length: int = 2) -> str:
# make all numbers the same length so they sort correctly, """Make all numbers the same length so they sort correctly.
# e.g. foo.3.bar -> foo.03.bar
e.g. foo.3.bar -> foo.03.bar
Args:
input: The string to pad.
padding_length: The length to pad the numbers to.
Returns:
The padded string.
"""
new_split: list[str] = [] new_split: list[str] = []
for s in input.split("_"): for s in input.split("_"):
if s.isdigit(): if s.isdigit():
@ -242,7 +252,7 @@ class SDLoraManager:
same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This
attempts to fix that issue, not cases where distant layers are called in a different attempts to fix that issue, not cases where distant layers are called in a different
order. order.
Args: Args:
key: The key to sort. key: The key to sort.
@ -250,6 +260,7 @@ class SDLoraManager:
The padded prefix of the key. The padded prefix of the key.
A score depending on the key's suffix. A score depending on the key's suffix.
""" """
# this dict might not be exhaustive # this dict might not be exhaustive
suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4} suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4}
patterns = ["_{}", "_{}_lora"] patterns = ["_{}", "_{}_lora"]
@ -260,7 +271,7 @@ class SDLoraManager:
# get the suffix and score for `key` (default: no suffix, highest score = 5) # get the suffix and score for `key` (default: no suffix, highest score = 5)
(sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), ("", 5)) (sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), ("", 5))
padded_key_prefix = SDLoraManager.pad(key.removesuffix(sfx)) padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx))
return (padded_key_prefix, score) return (padded_key_prefix, score)
@staticmethod @staticmethod