mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
Improve filtering when auto-attaching LoRAs.
Also support debug output to help diagnose bad mappings.
This commit is contained in:
parent
f8d55ccb20
commit
fafe5f8f5a
|
@ -133,15 +133,23 @@ class Lora(Generic[T], fl.Chain, ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
def auto_attach(
|
def auto_attach(
|
||||||
self, target: fl.Chain, exclude: list[str] | None = None
|
self,
|
||||||
|
target: fl.Chain,
|
||||||
|
include: list[str] | None = None,
|
||||||
|
exclude: list[str] | None = None,
|
||||||
) -> "tuple[LoraAdapter, fl.Chain | None] | None":
|
) -> "tuple[LoraAdapter, fl.Chain | None] | None":
|
||||||
for layer, parent in target.walk(self.up.__class__):
|
for layer, parent in target.walk(self.up.__class__):
|
||||||
if isinstance(parent, Lora):
|
if isinstance(parent, Lora):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if exclude is not None and any(
|
all_parents = []
|
||||||
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
if include is not None or exclude is not None:
|
||||||
):
|
all_parents = parent.get_parents() + [parent]
|
||||||
|
|
||||||
|
if include is not None and all((p.__class__.__name__ not in include) for p in all_parents):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if exclude is not None and any((p.__class__.__name__ in exclude) for p in all_parents):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not self.is_compatible(layer):
|
if not self.is_compatible(layer):
|
||||||
|
@ -443,24 +451,32 @@ def auto_attach_loras(
|
||||||
loras: dict[str, Lora[Any]],
|
loras: dict[str, Lora[Any]],
|
||||||
target: fl.Chain,
|
target: fl.Chain,
|
||||||
/,
|
/,
|
||||||
|
include: list[str] | None = None,
|
||||||
exclude: list[str] | None = None,
|
exclude: list[str] | None = None,
|
||||||
|
debug_map: list[tuple[str, str]] | None = None,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Auto-attach several LoRA layers to a Chain.
|
"""Auto-attach several LoRA layers to a Chain.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loras: A dictionary of LoRA layers associated to their respective key.
|
loras: A dictionary of LoRA layers associated to their respective key.
|
||||||
target: The target Chain.
|
target: The target Chain.
|
||||||
|
include: A list of layer names, only layers with such a layer in its parents will be considered.
|
||||||
|
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
|
||||||
|
debug_map: Pass a list to get a debug mapping of key - path pairs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of keys of LoRA layers which failed to attach.
|
A list of keys of LoRA layers which failed to attach.
|
||||||
"""
|
"""
|
||||||
failed_keys: list[str] = []
|
failed_keys: list[str] = []
|
||||||
for key, lora in loras.items():
|
for key, lora in loras.items():
|
||||||
if attached := lora.auto_attach(target, exclude=exclude):
|
if attached := lora.auto_attach(target, include=include, exclude=exclude):
|
||||||
adapter, parent = attached
|
adapter, parent = attached
|
||||||
if parent is None:
|
if parent is None:
|
||||||
# `adapter` is already attached and `lora` has been added to it
|
# `adapter` is already attached and `lora` has been added to it
|
||||||
continue
|
continue
|
||||||
|
if debug_map is not None:
|
||||||
|
path = adapter.target.get_path(parent)
|
||||||
|
debug_map.append((key, path))
|
||||||
adapter.inject(parent)
|
adapter.inject(parent)
|
||||||
else:
|
else:
|
||||||
failed_keys.append(key)
|
failed_keys.append(key)
|
||||||
|
|
Loading…
Reference in a new issue