diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index 4d73b7b..1743bc3 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -133,15 +133,23 @@ class Lora(Generic[T], fl.Chain, ABC): ... def auto_attach( - self, target: fl.Chain, exclude: list[str] | None = None + self, + target: fl.Chain, + include: list[str] | None = None, + exclude: list[str] | None = None, ) -> "tuple[LoraAdapter, fl.Chain | None] | None": for layer, parent in target.walk(self.up.__class__): if isinstance(parent, Lora): continue - if exclude is not None and any( - [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] - ): + all_parents = [] + if include is not None or exclude is not None: + all_parents = parent.get_parents() + [parent] + + if include is not None and all((p.__class__.__name__ not in include) for p in all_parents): + continue + + if exclude is not None and any((p.__class__.__name__ in exclude) for p in all_parents): continue if not self.is_compatible(layer): @@ -443,24 +451,32 @@ def auto_attach_loras( loras: dict[str, Lora[Any]], target: fl.Chain, /, + include: list[str] | None = None, exclude: list[str] | None = None, + debug_map: list[tuple[str, str]] | None = None, ) -> list[str]: """Auto-attach several LoRA layers to a Chain. Args: loras: A dictionary of LoRA layers associated to their respective key. target: The target Chain. + include: A list of layer names, only layers with such a layer in its parents will be considered. + exclude: A list of layer names, layers with such a layer in its parents will not be considered. + debug_map: Pass a list to get a debug mapping of key - path pairs. Returns: A list of keys of LoRA layers which failed to attach. """ failed_keys: list[str] = [] for key, lora in loras.items(): - if attached := lora.auto_attach(target, exclude=exclude): + if attached := lora.auto_attach(target, include=include, exclude=exclude): adapter, parent = attached if parent is None: # `adapter` is already attached and `lora` has been added to it continue + if debug_map is not None: + path = adapter.target.get_path(parent) + debug_map.append((key, path)) adapter.inject(parent) else: failed_keys.append(key)