add sanity check to auto_attach_loras

This commit is contained in:
Pierre Chapuis 2024-03-05 14:53:16 +01:00
parent 5593b40073
commit cce2a98fa6
3 changed files with 50 additions and 22 deletions

View file

@ -82,8 +82,7 @@ def load_lora_layers(
} }
# auto-attach the LoRA layers to the U-Net # auto-attach the LoRA layers to the U-Net
failed_keys = auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"]) auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers."
# eject all the LoRA adapters from the U-Net # eject all the LoRA adapters from the U-Net
# because we need each target path as if the adapter wasn't injected # because we need each target path as if the adapter wasn't injected

View file

@ -446,7 +446,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
return lora return lora
def auto_attach_loras( def _auto_attach_loras(
loras: dict[str, Lora[Any]], loras: dict[str, Lora[Any]],
target: fl.Chain, target: fl.Chain,
/, /,
@ -454,18 +454,6 @@ def auto_attach_loras(
exclude: list[str] | None = None, exclude: list[str] | None = None,
debug_map: list[tuple[str, str]] | None = None, debug_map: list[tuple[str, str]] | None = None,
) -> list[str]: ) -> list[str]:
"""Auto-attach several LoRA layers to a Chain.
Args:
loras: A dictionary of LoRA layers associated to their respective key.
target: The target Chain.
include: A list of layer names, only layers with such a layer in its parents will be considered.
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
Returns:
A list of keys of LoRA layers which failed to attach.
"""
failed_keys: list[str] = [] failed_keys: list[str] = []
for key, lora in loras.items(): for key, lora in loras.items():
if attached := lora.auto_attach(target, include=include, exclude=exclude): if attached := lora.auto_attach(target, include=include, exclude=exclude):
@ -481,3 +469,49 @@ def auto_attach_loras(
failed_keys.append(key) failed_keys.append(key)
return failed_keys return failed_keys
def auto_attach_loras(
loras: dict[str, Lora[Any]],
target: fl.Chain,
/,
include: list[str] | None = None,
exclude: list[str] | None = None,
sanity_check: bool = True,
debug_map: list[tuple[str, str]] | None = None,
) -> list[str]:
"""Auto-attach several LoRA layers to a Chain.
Args:
loras: A dictionary of LoRA layers associated to their respective key.
target: The target Chain.
include: A list of layer names, only layers with such a layer in its parents will be considered.
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
sanity_check: Check that LoRAs passed are correctly attached.
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
Returns:
A list of keys of LoRA layers which failed to attach.
"""
if not sanity_check:
return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map)
loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()}
debug_map_1: list[tuple[str, str]] = []
failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1)
if len(debug_map_1) != len(loras) or failed_keys_1:
raise ValueError(
f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
)
# Sanity check: if we re-run the attach, all layers should fail.
debug_map_2: list[tuple[str, str]] = []
failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
if debug_map_2 or len(failed_keys_2) != len(loras):
raise ValueError(
f"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} skipped"
)
if debug_map is not None:
debug_map += debug_map_1
return failed_keys_1

View file

@ -1,5 +1,4 @@
from typing import Any, Iterator, cast from typing import Any, Iterator, cast
from warnings import warn
from torch import Tensor from torch import Tensor
@ -115,9 +114,7 @@ class SDLoraManager:
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder) (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
""" """
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder) auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
if failed:
warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder")
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None: def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the U-Net. """Add multiple LoRAs to the U-Net.
@ -130,9 +127,7 @@ class SDLoraManager:
exclude = [ exclude = [
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()]) block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
] ]
failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude) auto_attach_loras(unet_loras, self.unet, exclude=exclude)
if failed:
warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet")
def remove_loras(self, *names: str) -> None: def remove_loras(self, *names: str) -> None:
"""Remove multiple LoRAs from the target. """Remove multiple LoRAs from the target.