diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index 19f6d00..29a4bb9 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -486,10 +486,11 @@ def auto_attach_loras( """Auto-attach several LoRA layers to a Chain. Args: - loras: A dictionary of LoRA layers associated to their respective key. + loras: A dictionary of LoRA layers associated to their respective key. The keys are typically + derived from the state dict and only used for `debug_map` and the return value. target: The target Chain. - include: A list of layer names, only layers with such a layer in its parents will be considered. - exclude: A list of layer names, layers with such a layer in its parents will not be considered. + include: A list of layer names, only layers with such a layer in their ancestors will be considered. + exclude: A list of layer names, layers with such a layer in their ancestors will not be considered. sanity_check: Check that LoRAs passed are correctly attached. debug_map: Pass a list to get a debug mapping of key - path pairs of attached points. Returns: @@ -507,7 +508,7 @@ def auto_attach_loras( f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed" ) - # Sanity check: if we re-run the attach, all layers should fail. + # Extra sanity check: if we re-run the attach, all layers should fail. debug_map_2: list[tuple[str, str]] = [] failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2) if debug_map_2 or len(failed_keys_2) != len(loras): diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index e70190c..16654b1 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -61,6 +61,19 @@ class SDLoraManager: name: The name of the LoRA. tensors: The `state_dict` of the LoRA to load. scale: The scale to use for the LoRA. + unet_inclusions: A list of layer names, only layers with such a layer + in their ancestors will be considered when patching the UNet. + unet_exclusions: A list of layer names, layers with such a layer in + their ancestors will not be considered when patching the UNet. + If this is `None` then it defaults to `["TimestepEncoder"]`. + unet_preprocess: A map between parts of state dict keys and layer names. + This is used to attach some keys to specific parts of the UNet. + You should leave it set to `None` (it has a default value), + otherwise read the source code to understand how it works. + text_encoder_inclusions: A list of layer names, only layers with such a layer + in their ancestors will be considered when patching the text encoder. + text_encoder_exclusions: A list of layer names, layers with such a layer in + their ancestors will not be considered when patching the text encoder. Raises: AssertionError: If the Manager already has a LoRA with the same name. @@ -117,15 +130,22 @@ class SDLoraManager: /, include: list[str] | None = None, exclude: list[str] | None = None, + debug_map: list[tuple[str, str]] | None = None, ) -> None: - """Add multiple LoRAs to the text encoder. + """Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments. Args: loras: The dictionary of LoRAs to add to the text encoder. (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder) """ text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} - auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include) + auto_attach_loras( + text_encoder_loras, + self.clip_text_encoder, + exclude=exclude, + include=include, + debug_map=debug_map, + ) def add_loras_to_unet( self, @@ -136,7 +156,7 @@ class SDLoraManager: preprocess: dict[str, str] | None = None, debug_map: list[tuple[str, str]] | None = None, ) -> None: - """Add multiple LoRAs to the U-Net. + """Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments. Args: loras: The dictionary of LoRAs to add to the U-Net.