mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add sanity check to auto_attach_loras
This commit is contained in:
parent
5593b40073
commit
cce2a98fa6
|
@ -82,8 +82,7 @@ def load_lora_layers(
|
|||
}
|
||||
|
||||
# auto-attach the LoRA layers to the U-Net
|
||||
failed_keys = auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
|
||||
assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers."
|
||||
auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
|
||||
|
||||
# eject all the LoRA adapters from the U-Net
|
||||
# because we need each target path as if the adapter wasn't injected
|
||||
|
|
|
@ -446,7 +446,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
|||
return lora
|
||||
|
||||
|
||||
def auto_attach_loras(
|
||||
def _auto_attach_loras(
|
||||
loras: dict[str, Lora[Any]],
|
||||
target: fl.Chain,
|
||||
/,
|
||||
|
@ -454,18 +454,6 @@ def auto_attach_loras(
|
|||
exclude: list[str] | None = None,
|
||||
debug_map: list[tuple[str, str]] | None = None,
|
||||
) -> list[str]:
|
||||
"""Auto-attach several LoRA layers to a Chain.
|
||||
|
||||
Args:
|
||||
loras: A dictionary of LoRA layers associated to their respective key.
|
||||
target: The target Chain.
|
||||
include: A list of layer names, only layers with such a layer in its parents will be considered.
|
||||
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
|
||||
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
|
||||
|
||||
Returns:
|
||||
A list of keys of LoRA layers which failed to attach.
|
||||
"""
|
||||
failed_keys: list[str] = []
|
||||
for key, lora in loras.items():
|
||||
if attached := lora.auto_attach(target, include=include, exclude=exclude):
|
||||
|
@ -481,3 +469,49 @@ def auto_attach_loras(
|
|||
failed_keys.append(key)
|
||||
|
||||
return failed_keys
|
||||
|
||||
|
||||
def auto_attach_loras(
|
||||
loras: dict[str, Lora[Any]],
|
||||
target: fl.Chain,
|
||||
/,
|
||||
include: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
sanity_check: bool = True,
|
||||
debug_map: list[tuple[str, str]] | None = None,
|
||||
) -> list[str]:
|
||||
"""Auto-attach several LoRA layers to a Chain.
|
||||
|
||||
Args:
|
||||
loras: A dictionary of LoRA layers associated to their respective key.
|
||||
target: The target Chain.
|
||||
include: A list of layer names, only layers with such a layer in its parents will be considered.
|
||||
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
|
||||
sanity_check: Check that LoRAs passed are correctly attached.
|
||||
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
|
||||
Returns:
|
||||
A list of keys of LoRA layers which failed to attach.
|
||||
"""
|
||||
|
||||
if not sanity_check:
|
||||
return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map)
|
||||
|
||||
loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()}
|
||||
debug_map_1: list[tuple[str, str]] = []
|
||||
failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1)
|
||||
if len(debug_map_1) != len(loras) or failed_keys_1:
|
||||
raise ValueError(
|
||||
f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
|
||||
)
|
||||
|
||||
# Sanity check: if we re-run the attach, all layers should fail.
|
||||
debug_map_2: list[tuple[str, str]] = []
|
||||
failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
|
||||
if debug_map_2 or len(failed_keys_2) != len(loras):
|
||||
raise ValueError(
|
||||
f"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} skipped"
|
||||
)
|
||||
|
||||
if debug_map is not None:
|
||||
debug_map += debug_map_1
|
||||
return failed_keys_1
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Any, Iterator, cast
|
||||
from warnings import warn
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -115,9 +114,7 @@ class SDLoraManager:
|
|||
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
||||
"""
|
||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||
failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
|
||||
if failed:
|
||||
warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder")
|
||||
auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
|
||||
|
||||
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
|
||||
"""Add multiple LoRAs to the U-Net.
|
||||
|
@ -130,9 +127,7 @@ class SDLoraManager:
|
|||
exclude = [
|
||||
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
|
||||
]
|
||||
failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude)
|
||||
if failed:
|
||||
warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet")
|
||||
auto_attach_loras(unet_loras, self.unet, exclude=exclude)
|
||||
|
||||
def remove_loras(self, *names: str) -> None:
|
||||
"""Remove multiple LoRAs from the target.
|
||||
|
|
Loading…
Reference in a new issue