From cce2a98fa668f7a6beef780cb2f06a0f684a4ccf Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 5 Mar 2024 14:53:16 +0100 Subject: [PATCH] add sanity check to auto_attach_loras --- .../convert_fooocus_control_lora.py | 3 +- src/refiners/fluxion/adapters/lora.py | 60 +++++++++++++++---- .../foundationals/latent_diffusion/lora.py | 9 +-- 3 files changed, 50 insertions(+), 22 deletions(-) diff --git a/scripts/conversion/convert_fooocus_control_lora.py b/scripts/conversion/convert_fooocus_control_lora.py index ceec2d2..b378973 100644 --- a/scripts/conversion/convert_fooocus_control_lora.py +++ b/scripts/conversion/convert_fooocus_control_lora.py @@ -82,8 +82,7 @@ def load_lora_layers( } # auto-attach the LoRA layers to the U-Net - failed_keys = auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"]) - assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers." + auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"]) # eject all the LoRA adapters from the U-Net # because we need each target path as if the adapter wasn't injected diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index 2868159..043a973 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -446,7 +446,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]): return lora -def auto_attach_loras( +def _auto_attach_loras( loras: dict[str, Lora[Any]], target: fl.Chain, /, @@ -454,18 +454,6 @@ def auto_attach_loras( exclude: list[str] | None = None, debug_map: list[tuple[str, str]] | None = None, ) -> list[str]: - """Auto-attach several LoRA layers to a Chain. - - Args: - loras: A dictionary of LoRA layers associated to their respective key. - target: The target Chain. - include: A list of layer names, only layers with such a layer in its parents will be considered. - exclude: A list of layer names, layers with such a layer in its parents will not be considered. - debug_map: Pass a list to get a debug mapping of key - path pairs of attached points. - - Returns: - A list of keys of LoRA layers which failed to attach. - """ failed_keys: list[str] = [] for key, lora in loras.items(): if attached := lora.auto_attach(target, include=include, exclude=exclude): @@ -481,3 +469,49 @@ def auto_attach_loras( failed_keys.append(key) return failed_keys + + +def auto_attach_loras( + loras: dict[str, Lora[Any]], + target: fl.Chain, + /, + include: list[str] | None = None, + exclude: list[str] | None = None, + sanity_check: bool = True, + debug_map: list[tuple[str, str]] | None = None, +) -> list[str]: + """Auto-attach several LoRA layers to a Chain. + + Args: + loras: A dictionary of LoRA layers associated to their respective key. + target: The target Chain. + include: A list of layer names, only layers with such a layer in its parents will be considered. + exclude: A list of layer names, layers with such a layer in its parents will not be considered. + sanity_check: Check that LoRAs passed are correctly attached. + debug_map: Pass a list to get a debug mapping of key - path pairs of attached points. + Returns: + A list of keys of LoRA layers which failed to attach. + """ + + if not sanity_check: + return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map) + + loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()} + debug_map_1: list[tuple[str, str]] = [] + failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1) + if len(debug_map_1) != len(loras) or failed_keys_1: + raise ValueError( + f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed" + ) + + # Sanity check: if we re-run the attach, all layers should fail. + debug_map_2: list[tuple[str, str]] = [] + failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2) + if debug_map_2 or len(failed_keys_2) != len(loras): + raise ValueError( + f"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} skipped" + ) + + if debug_map is not None: + debug_map += debug_map_1 + return failed_keys_1 diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index 9f853f2..822a0c4 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -1,5 +1,4 @@ from typing import Any, Iterator, cast -from warnings import warn from torch import Tensor @@ -115,9 +114,7 @@ class SDLoraManager: (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder) """ text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} - failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder) - if failed: - warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder") + auto_attach_loras(text_encoder_loras, self.clip_text_encoder) def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None: """Add multiple LoRAs to the U-Net. @@ -130,9 +127,7 @@ class SDLoraManager: exclude = [ block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()]) ] - failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude) - if failed: - warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet") + auto_attach_loras(unet_loras, self.unet, exclude=exclude) def remove_loras(self, *names: str) -> None: """Remove multiple LoRAs from the target.