improve docstrings

Pierre Chapuis 2024-03-08 14:37:18 +01:00
parent 5d784bedab
commit 975560165c
2 changed files with 28 additions and 7 deletions


@@ -486,10 +486,11 @@ def auto_attach_loras(
     """Auto-attach several LoRA layers to a Chain.
 
     Args:
-        loras: A dictionary of LoRA layers associated to their respective key.
+        loras: A dictionary of LoRA layers associated to their respective key. The keys are typically
+            derived from the state dict and only used for `debug_map` and the return value.
         target: The target Chain.
-        include: A list of layer names, only layers with such a layer in its parents will be considered.
-        exclude: A list of layer names, layers with such a layer in its parents will not be considered.
+        include: A list of layer names, only layers with such a layer in their ancestors will be considered.
+        exclude: A list of layer names, layers with such a layer in their ancestors will not be considered.
         sanity_check: Check that LoRAs passed are correctly attached.
         debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
     Returns:
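To make the reworded arguments concrete, here is a minimal usage sketch (not part of the commit). The import path, the pre-built `loras` dict, the `unet` target and the layer names are assumptions for illustration only:

# Sketch: `loras` is assumed to be a dict mapping state-dict keys to Lora layers,
# and `unet` an fl.Chain to patch (both built elsewhere).
from refiners.fluxion.adapters.lora import auto_attach_loras  # import path is an assumption

debug_map: list[tuple[str, str]] = []
failed_keys = auto_attach_loras(
    loras,
    unet,
    include=["CrossAttentionBlock"],  # only layers with such an ancestor are considered (illustrative name)
    exclude=["TimestepEncoder"],      # layers with such an ancestor are skipped (illustrative name)
    sanity_check=True,
    debug_map=debug_map,              # filled with (key, path) pairs for each attached LoRA
)
# The return value lists the keys whose LoRA layers could not be attached.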
@@ -507,7 +508,7 @@ def auto_attach_loras(
             f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
         )
 
-    # Sanity check: if we re-run the attach, all layers should fail.
+    # Extra sanity check: if we re-run the attach, all layers should fail.
     debug_map_2: list[tuple[str, str]] = []
     failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
     if debug_map_2 or len(failed_keys_2) != len(loras):


@@ -61,6 +61,19 @@ class SDLoraManager:
             name: The name of the LoRA.
             tensors: The `state_dict` of the LoRA to load.
             scale: The scale to use for the LoRA.
+            unet_inclusions: A list of layer names, only layers with such a layer
+                in their ancestors will be considered when patching the UNet.
+            unet_exclusions: A list of layer names, layers with such a layer in
+                their ancestors will not be considered when patching the UNet.
+                If this is `None` then it defaults to `["TimestepEncoder"]`.
+            unet_preprocess: A map between parts of state dict keys and layer names.
+                This is used to attach some keys to specific parts of the UNet.
+                You should leave it set to `None` (it has a default value),
+                otherwise read the source code to understand how it works.
+            text_encoder_inclusions: A list of layer names, only layers with such a layer
+                in their ancestors will be considered when patching the text encoder.
+            text_encoder_exclusions: A list of layer names, layers with such a layer in
+                their ancestors will not be considered when patching the text encoder.
 
         Raises:
             AssertionError: If the Manager already has a LoRA with the same name.
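For reference, a hedged sketch of how the newly documented arguments might be passed. The import paths, the checkpoint file name and the `sd15` model are assumptions, not part of the commit:

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion.lora import SDLoraManager  # import path is an assumption

manager = SDLoraManager(sd15)  # assumes `sd15` is an already-instantiated latent diffusion model
manager.add_loras(
    "my-style-lora",  # hypothetical LoRA name
    tensors=load_from_safetensors("my-style-lora.safetensors"),  # hypothetical checkpoint
    scale=1.0,
    unet_exclusions=["TimestepEncoder"],  # same as the documented default when None
    unet_preprocess=None,                 # keep the default preprocessing map, as the docstring advises
    text_encoder_inclusions=None,         # consider the whole text encoder
)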
@@ -117,15 +130,22 @@ class SDLoraManager:
         /,
         include: list[str] | None = None,
         exclude: list[str] | None = None,
+        debug_map: list[tuple[str, str]] | None = None,
     ) -> None:
-        """Add multiple LoRAs to the text encoder.
+        """Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments.
 
         Args:
             loras: The dictionary of LoRAs to add to the text encoder.
                 (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
         """
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
-        auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include)
+        auto_attach_loras(
+            text_encoder_loras,
+            self.clip_text_encoder,
+            exclude=exclude,
+            include=include,
+            debug_map=debug_map,
+        )
 
     def add_loras_to_unet(
         self,
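A possible use of the new `debug_map` parameter (sketch; `manager` and a `loras` dict of Lora layers keyed by state-dict names are assumed to exist):

debug_map: list[tuple[str, str]] = []
manager.add_loras_to_text_encoder(loras, debug_map=debug_map)
for key, path in debug_map:
    print(f"{key} -> {path}")  # inspect where each text-encoder LoRA key was attached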
@@ -136,7 +156,7 @@ class SDLoraManager:
         preprocess: dict[str, str] | None = None,
         debug_map: list[tuple[str, str]] | None = None,
     ) -> None:
-        """Add multiple LoRAs to the U-Net.
+        """Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments.
 
         Args:
             loras: The dictionary of LoRAs to add to the U-Net.
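The U-Net variant can be driven the same way (sketch; only the arguments visible in this diff are used, everything else is assumed):

debug_map = []
manager.add_loras_to_unet(loras, debug_map=debug_map)
print(f"attached {len(debug_map)} LoRA layers to the U-Net")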