From 396d16656443d89796f21c5e9031b1fae1626851 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 7 Feb 2024 16:08:52 +0100 Subject: [PATCH] make pad method private --- .../foundationals/latent_diffusion/lora.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index a364740..7ff2deb 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -222,9 +222,19 @@ class SDLoraManager: return {name: self.get_scale(name) for name in self.names} @staticmethod - def pad(input: str, /, padding_length: int = 2) -> str: - # make all numbers the same length so they sort correctly, - # e.g. foo.3.bar -> foo.03.bar + def _pad(input: str, /, padding_length: int = 2) -> str: + """Make all numbers the same length so they sort correctly. + + e.g. foo.3.bar -> foo.03.bar + + Args: + input: The string to pad. + padding_length: The length to pad the numbers to. + + Returns: + The padded string. + """ + new_split: list[str] = [] for s in input.split("_"): if s.isdigit(): @@ -242,7 +252,7 @@ class SDLoraManager: same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This attempts to fix that issue, not cases where distant layers are called in a different order. - + Args: key: The key to sort. @@ -250,6 +260,7 @@ class SDLoraManager: The padded prefix of the key. A score depending on the key's suffix. """ + # this dict might not be exhaustive suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4} patterns = ["_{}", "_{}_lora"] @@ -260,7 +271,7 @@ class SDLoraManager: # get the suffix and score for `key` (default: no suffix, highest score = 5) (sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), ("", 5)) - padded_key_prefix = SDLoraManager.pad(key.removesuffix(sfx)) + padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx)) return (padded_key_prefix, score) @staticmethod