Improve filtering when auto-attaching LoRAs.

Also support debug output to help diagnose bad mappings.
This commit is contained in:
Pierre Chapuis 2024-02-15 18:59:32 +01:00
parent f8d55ccb20
commit fafe5f8f5a

View file

@ -133,15 +133,23 @@ class Lora(Generic[T], fl.Chain, ABC):
... ...
def auto_attach( def auto_attach(
self, target: fl.Chain, exclude: list[str] | None = None self,
target: fl.Chain,
include: list[str] | None = None,
exclude: list[str] | None = None,
) -> "tuple[LoraAdapter, fl.Chain | None] | None": ) -> "tuple[LoraAdapter, fl.Chain | None] | None":
for layer, parent in target.walk(self.up.__class__): for layer, parent in target.walk(self.up.__class__):
if isinstance(parent, Lora): if isinstance(parent, Lora):
continue continue
if exclude is not None and any( all_parents = []
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] if include is not None or exclude is not None:
): all_parents = parent.get_parents() + [parent]
if include is not None and all((p.__class__.__name__ not in include) for p in all_parents):
continue
if exclude is not None and any((p.__class__.__name__ in exclude) for p in all_parents):
continue continue
if not self.is_compatible(layer): if not self.is_compatible(layer):
@ -443,24 +451,32 @@ def auto_attach_loras(
loras: dict[str, Lora[Any]], loras: dict[str, Lora[Any]],
target: fl.Chain, target: fl.Chain,
/, /,
include: list[str] | None = None,
exclude: list[str] | None = None, exclude: list[str] | None = None,
debug_map: list[tuple[str, str]] | None = None,
) -> list[str]: ) -> list[str]:
"""Auto-attach several LoRA layers to a Chain. """Auto-attach several LoRA layers to a Chain.
Args: Args:
loras: A dictionary of LoRA layers associated to their respective key. loras: A dictionary of LoRA layers associated to their respective key.
target: The target Chain. target: The target Chain.
include: A list of layer names, only layers with such a layer in its parents will be considered.
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
debug_map: Pass a list to get a debug mapping of key - path pairs.
Returns: Returns:
A list of keys of LoRA layers which failed to attach. A list of keys of LoRA layers which failed to attach.
""" """
failed_keys: list[str] = [] failed_keys: list[str] = []
for key, lora in loras.items(): for key, lora in loras.items():
if attached := lora.auto_attach(target, exclude=exclude): if attached := lora.auto_attach(target, include=include, exclude=exclude):
adapter, parent = attached adapter, parent = attached
if parent is None: if parent is None:
# `adapter` is already attached and `lora` has been added to it # `adapter` is already attached and `lora` has been added to it
continue continue
if debug_map is not None:
path = adapter.target.get_path(parent)
debug_map.append((key, path))
adapter.inject(parent) adapter.inject(parent)
else: else:
failed_keys.append(key) failed_keys.append(key)