Move helper to attach several LoRAs from SD to Fluxion

This commit is contained in:
Pierre Chapuis 2024-02-14 11:37:55 +01:00
parent bec845553f
commit 35868ba34b
2 changed files with 36 additions and 25 deletions

View file

@ -435,3 +435,32 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
lora = self.loras[name] lora = self.loras[name]
self.remove(lora) self.remove(lora)
return lora return lora
def auto_attach_loras(
loras: dict[str, Lora[Any]],
target: fl.Chain,
/,
exclude: list[str] | None = None,
) -> list[str]:
"""Auto-attach several LoRA layers to a Chain.
Args:
loras: A dictionary of LoRA layers associated to their respective key.
target: The target Chain.
Returns:
A list of keys of LoRA layers which failed to attach.
"""
failed_keys: list[str] = []
for key, lora in loras.items():
if attached := lora.auto_attach(target, exclude=exclude):
adapter, parent = attached
if parent is None:
# `adapter` is already attached and `lora` has been added to it
continue
adapter.inject(parent)
else:
failed_keys.append(key)
return failed_keys

View file

@ -4,7 +4,7 @@ from warnings import warn
from torch import Tensor from torch import Tensor
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora, LoraAdapter from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
@ -115,7 +115,9 @@ class SDLoraManager:
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder) (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
""" """
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder) failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
if failed:
warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder")
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None: def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the U-Net. """Add multiple LoRAs to the U-Net.
@ -128,7 +130,9 @@ class SDLoraManager:
exclude = [ exclude = [
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()]) block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
] ]
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude) failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude)
if failed:
warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet")
def remove_loras(self, *names: str) -> None: def remove_loras(self, *names: str) -> None:
"""Remove multiple LoRAs from the target. """Remove multiple LoRAs from the target.
@ -273,25 +277,3 @@ class SDLoraManager:
padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx)) padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx))
return (padded_key_prefix, score) return (padded_key_prefix, score)
@staticmethod
def auto_attach(
loras: dict[str, Lora[Any]],
target: fl.Chain,
/,
exclude: list[str] | None = None,
) -> None:
failed_loras: dict[str, Lora[Any]] = {}
for key, lora in loras.items():
if attach := lora.auto_attach(target, exclude=exclude):
adapter, parent = attach
# if parent is None, `adapter` is already attached and `lora` has been added to it
if parent is not None:
adapter.inject(parent)
else:
failed_loras[key] = lora
if failed_loras:
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
# TODO: add a stronger sanity check to make sure loras are attached correctly