mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
Move helper to attach several LoRAs from SD to Fluxion
This commit is contained in:
parent
bec845553f
commit
35868ba34b
|
@ -435,3 +435,32 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
||||||
lora = self.loras[name]
|
lora = self.loras[name]
|
||||||
self.remove(lora)
|
self.remove(lora)
|
||||||
return lora
|
return lora
|
||||||
|
|
||||||
|
|
||||||
|
def auto_attach_loras(
|
||||||
|
loras: dict[str, Lora[Any]],
|
||||||
|
target: fl.Chain,
|
||||||
|
/,
|
||||||
|
exclude: list[str] | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Auto-attach several LoRA layers to a Chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loras: A dictionary of LoRA layers associated to their respective key.
|
||||||
|
target: The target Chain.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of keys of LoRA layers which failed to attach.
|
||||||
|
"""
|
||||||
|
failed_keys: list[str] = []
|
||||||
|
for key, lora in loras.items():
|
||||||
|
if attached := lora.auto_attach(target, exclude=exclude):
|
||||||
|
adapter, parent = attached
|
||||||
|
if parent is None:
|
||||||
|
# `adapter` is already attached and `lora` has been added to it
|
||||||
|
continue
|
||||||
|
adapter.inject(parent)
|
||||||
|
else:
|
||||||
|
failed_keys.append(key)
|
||||||
|
|
||||||
|
return failed_keys
|
||||||
|
|
|
@ -4,7 +4,7 @@ from warnings import warn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras
|
||||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -115,7 +115,9 @@ class SDLoraManager:
|
||||||
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
||||||
"""
|
"""
|
||||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||||
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
|
failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
|
||||||
|
if failed:
|
||||||
|
warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder")
|
||||||
|
|
||||||
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
|
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
|
||||||
"""Add multiple LoRAs to the U-Net.
|
"""Add multiple LoRAs to the U-Net.
|
||||||
|
@ -128,7 +130,9 @@ class SDLoraManager:
|
||||||
exclude = [
|
exclude = [
|
||||||
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
|
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
|
||||||
]
|
]
|
||||||
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
|
failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude)
|
||||||
|
if failed:
|
||||||
|
warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet")
|
||||||
|
|
||||||
def remove_loras(self, *names: str) -> None:
|
def remove_loras(self, *names: str) -> None:
|
||||||
"""Remove multiple LoRAs from the target.
|
"""Remove multiple LoRAs from the target.
|
||||||
|
@ -273,25 +277,3 @@ class SDLoraManager:
|
||||||
|
|
||||||
padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx))
|
padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx))
|
||||||
return (padded_key_prefix, score)
|
return (padded_key_prefix, score)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def auto_attach(
|
|
||||||
loras: dict[str, Lora[Any]],
|
|
||||||
target: fl.Chain,
|
|
||||||
/,
|
|
||||||
exclude: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
failed_loras: dict[str, Lora[Any]] = {}
|
|
||||||
for key, lora in loras.items():
|
|
||||||
if attach := lora.auto_attach(target, exclude=exclude):
|
|
||||||
adapter, parent = attach
|
|
||||||
# if parent is None, `adapter` is already attached and `lora` has been added to it
|
|
||||||
if parent is not None:
|
|
||||||
adapter.inject(parent)
|
|
||||||
else:
|
|
||||||
failed_loras[key] = lora
|
|
||||||
|
|
||||||
if failed_loras:
|
|
||||||
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
|
|
||||||
|
|
||||||
# TODO: add a stronger sanity check to make sure loras are attached correctly
|
|
||||||
|
|
Loading…
Reference in a new issue