diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index e773359..2a094c3 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -435,3 +435,32 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]): lora = self.loras[name] self.remove(lora) return lora + + +def auto_attach_loras( + loras: dict[str, Lora[Any]], + target: fl.Chain, + /, + exclude: list[str] | None = None, +) -> list[str]: + """Auto-attach several LoRA layers to a Chain. + + Args: + loras: A dictionary of LoRA layers associated to their respective key. + target: The target Chain. + + Returns: + A list of keys of LoRA layers which failed to attach. + """ + failed_keys: list[str] = [] + for key, lora in loras.items(): + if attached := lora.auto_attach(target, exclude=exclude): + adapter, parent = attached + if parent is None: + # `adapter` is already attached and `lora` has been added to it + continue + adapter.inject(parent) + else: + failed_keys.append(key) + + return failed_keys diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index 429ad42..9f853f2 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -4,7 +4,7 @@ from warnings import warn from torch import Tensor import refiners.fluxion.layers as fl -from refiners.fluxion.adapters.lora import Lora, LoraAdapter +from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel @@ -115,7 +115,9 @@ class SDLoraManager: (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder) """ text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} - SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder) + failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder) + if failed: + warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder") def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None: """Add multiple LoRAs to the U-Net. @@ -128,7 +130,9 @@ class SDLoraManager: exclude = [ block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()]) ] - SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude) + failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude) + if failed: + warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet") def remove_loras(self, *names: str) -> None: """Remove multiple LoRAs from the target. @@ -273,25 +277,3 @@ class SDLoraManager: padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx)) return (padded_key_prefix, score) - - @staticmethod - def auto_attach( - loras: dict[str, Lora[Any]], - target: fl.Chain, - /, - exclude: list[str] | None = None, - ) -> None: - failed_loras: dict[str, Lora[Any]] = {} - for key, lora in loras.items(): - if attach := lora.auto_attach(target, exclude=exclude): - adapter, parent = attach - # if parent is None, `adapter` is already attached and `lora` has been added to it - if parent is not None: - adapter.inject(parent) - else: - failed_loras[key] = lora - - if failed_loras: - warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}") - - # TODO: add a stronger sanity check to make sure loras are attached correctly