Mirror of https://github.com/finegrain-ai/refiners.git
add sanity check to auto_attach_loras
parent 5593b40073
commit cce2a98fa6
@@ -82,8 +82,7 @@ def load_lora_layers(
     }

     # auto-attach the LoRA layers to the U-Net
-    failed_keys = auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
-    assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers."
+    auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])

     # eject all the LoRA adapters from the U-Net
     # because we need each target path as if the adapter wasn't injected
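The explicit assert is dropped here because auto_attach_loras now performs this check itself (via the new sanity_check parameter introduced below) and raises a ValueError on failure. As a rough sketch, not part of this commit, a caller that still wants the old return-value-only behaviour could opt out of the built-in check; the names used are the ones from the surrounding function:

    # Illustrative sketch only: reproduce the pre-commit behaviour by disabling
    # the built-in sanity check and asserting on the returned keys manually.
    failed_keys = auto_attach_loras(
        lora_layers,
        control_lora,
        exclude=["ZeroConvolution", "ConditionEncoder"],
        sanity_check=False,  # skip the new ValueError-raising check
    )
    assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers."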
@@ -446,7 +446,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
         return lora


-def auto_attach_loras(
+def _auto_attach_loras(
     loras: dict[str, Lora[Any]],
     target: fl.Chain,
     /,
@@ -454,18 +454,6 @@ def auto_attach_loras(
     exclude: list[str] | None = None,
     debug_map: list[tuple[str, str]] | None = None,
 ) -> list[str]:
-    """Auto-attach several LoRA layers to a Chain.
-
-    Args:
-        loras: A dictionary of LoRA layers associated to their respective key.
-        target: The target Chain.
-        include: A list of layer names, only layers with such a layer in its parents will be considered.
-        exclude: A list of layer names, layers with such a layer in its parents will not be considered.
-        debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
-
-    Returns:
-        A list of keys of LoRA layers which failed to attach.
-    """
     failed_keys: list[str] = []
     for key, lora in loras.items():
         if attached := lora.auto_attach(target, include=include, exclude=exclude):
@@ -481,3 +469,49 @@ def auto_attach_loras(
             failed_keys.append(key)

     return failed_keys
+
+
+def auto_attach_loras(
+    loras: dict[str, Lora[Any]],
+    target: fl.Chain,
+    /,
+    include: list[str] | None = None,
+    exclude: list[str] | None = None,
+    sanity_check: bool = True,
+    debug_map: list[tuple[str, str]] | None = None,
+) -> list[str]:
+    """Auto-attach several LoRA layers to a Chain.
+
+    Args:
+        loras: A dictionary of LoRA layers associated to their respective key.
+        target: The target Chain.
+        include: A list of layer names, only layers with such a layer in its parents will be considered.
+        exclude: A list of layer names, layers with such a layer in its parents will not be considered.
+        sanity_check: Check that LoRAs passed are correctly attached.
+        debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
+
+    Returns:
+        A list of keys of LoRA layers which failed to attach.
+    """
+    if not sanity_check:
+        return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map)
+
+    loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()}
+    debug_map_1: list[tuple[str, str]] = []
+    failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1)
+    if len(debug_map_1) != len(loras) or failed_keys_1:
+        raise ValueError(
+            f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
+        )
+
+    # Sanity check: if we re-run the attach, all layers should fail.
+    debug_map_2: list[tuple[str, str]] = []
+    failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
+    if debug_map_2 or len(failed_keys_2) != len(loras):
+        raise ValueError(
+            f"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} skipped"
+        )
+
+    if debug_map is not None:
+        debug_map += debug_map_1
+    return failed_keys_1
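A minimal usage sketch of the new wrapper follows; the `my_loras` dict and `unet` target are placeholder names, not part of this commit. With the default sanity_check=True, attachment problems now surface as a ValueError rather than only as returned keys:

    # Hypothetical caller: `my_loras` is a dict[str, Lora[Any]], `unet` is an fl.Chain.
    debug_map: list[tuple[str, str]] = []
    try:
        auto_attach_loras(my_loras, unet, debug_map=debug_map)
    except ValueError as e:
        # Raised when the sanity check finds missing or duplicated attachments.
        print(f"LoRA attachment failed: {e}")
    else:
        # debug_map now holds the key/path pairs of the attached points.
        for key, path in debug_map:
            print(f"{key} -> {path}")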
@@ -1,5 +1,4 @@
 from typing import Any, Iterator, cast
-from warnings import warn

 from torch import Tensor

@@ -115,9 +114,7 @@ class SDLoraManager:
                 (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
         """
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
-        failed = auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
-        if failed:
-            warn(f"failed to attach {len(failed)}/{len(text_encoder_loras)} loras to the text encoder")
+        auto_attach_loras(text_encoder_loras, self.clip_text_encoder)

     def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the U-Net.
@@ -130,9 +127,7 @@ class SDLoraManager:
         exclude = [
             block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
         ]
-        failed = auto_attach_loras(unet_loras, self.unet, exclude=exclude)
-        if failed:
-            warn(f"failed to attach {len(failed)}/{len(unet_loras)} loras to the unet")
+        auto_attach_loras(unet_loras, self.unet, exclude=exclude)

     def remove_loras(self, *names: str) -> None:
         """Remove multiple LoRAs from the target.
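Since SDLoraManager no longer downgrades attachment failures to warnings, the exception from auto_attach_loras propagates to the caller. A hedged sketch of restoring warn-style behaviour at the call site, where `manager` is an SDLoraManager instance and `loras` a dict of LoRA layers (both assumed, not part of the commit):

    from warnings import warn

    try:
        manager.add_loras_to_unet(loras)
    except ValueError as e:
        # auto_attach_loras raises when its sanity check fails; downgrade it to a
        # warning here only if a hard failure is not acceptable for your use case.
        warn(f"failed to attach LoRAs to the unet: {e}")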