add sanity check to auto_attach_loras

This commit is contained in:
Pierre Chapuis 2024-03-05 14:53:16 +01:00
parent 5593b40073
commit cce2a98fa6
3 changed files with 50 additions and 22 deletions

View file

@ -82,8 +82,7 @@ def load_lora_layers(
}
# auto-attach the LoRA layers to the U-Net
failed_keys = auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers."
auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
# eject all the LoRA adapters from the U-Net
# because we need each target path as if the adapter wasn't injected

View file

@ -446,7 +446,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
return lora
def auto_attach_loras(
def _auto_attach_loras(
loras: dict[str, Lora[Any]],
target: fl.Chain,
/,
@ -454,18 +454,6 @@ def auto_attach_loras(
exclude: list[str] | None = None,
debug_map: list[tuple[str, str]] | None = None,
) -> list[str]:
"""Auto-attach several LoRA layers to a Chain.
Args:
loras: A dictionary of LoRA layers associated to their respective key.
target: The target Chain.
include: A list of layer names, only layers with such a layer in its parents will be considered.
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
Returns:
A list of keys of LoRA layers which failed to attach.
"""
failed_keys: list[str] = []
for key, lora in loras.items():
if attached := lora.auto_attach(target, include=include, exclude=exclude):
@ -481,3 +469,49 @@ def auto_attach_loras(
failed_keys.append(key)
return failed_keys
def auto_attach_loras(
loras: dict[str, Lora[Any]],
target: fl.Chain,
/,
include: list[str] | None = None,
exclude: list[str] | None = None,
sanity_check: bool = True,
debug_map: list[tuple[str, str]] | None = None,
) -> list[str]:
"""Auto-attach several LoRA layers to a Chain.
Args:
loras: A dictionary of LoRA layers associated to their respective key.
target: The target Chain.
include: A list of layer names, only layers with such a layer in its parents will be considered.
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
sanity_check: Check that LoRAs passed are correctly attached.
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
Returns:
A list of keys of LoRA layers which failed to attach.
"""
if not sanity_check:
return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map)
loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()}
debug_map_1: list[tuple[str, str]] = []
failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1)
if len(debug_map_1) != len(loras) or failed_keys_1:
raise ValueError(
f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
)
# Sanity check: if we re-run the attach, all layers should fail.
debug_map_2: list[tuple[str, str]] = []
failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
if debug_map_2 or len(failed_keys_2) != len(loras):
raise ValueError(
f"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} skipped"
)
if debug_map is not None:
debug_map += debug_map_1
return failed_keys_1

View file

@ -1,5 +1,4 @@
from typing import Any, Iterator, cast
from warnings import warn
from torch import Tensor
@ -115,9 +114,7 @@ class SDLoraManager:
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
"""
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
if failed:
warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder")
auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the U-Net.
@ -130,9 +127,7 @@ class SDLoraManager:
exclude = [
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
]
failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude)
if failed:
warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet")
auto_attach_loras(unet_loras, self.unet, exclude=exclude)
def remove_loras(self, *names: str) -> None:
"""Remove multiple LoRAs from the target.