improve docstrings

This commit is contained in:
Pierre Chapuis 2024-03-08 14:37:18 +01:00
parent 5d784bedab
commit 975560165c
2 changed files with 28 additions and 7 deletions

View file

@ -486,10 +486,11 @@ def auto_attach_loras(
"""Auto-attach several LoRA layers to a Chain.
Args:
loras: A dictionary of LoRA layers associated to their respective key.
loras: A dictionary of LoRA layers associated to their respective key. The keys are typically
derived from the state dict and only used for `debug_map` and the return value.
target: The target Chain.
include: A list of layer names, only layers with such a layer in its parents will be considered.
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
include: A list of layer names, only layers with such a layer in their ancestors will be considered.
exclude: A list of layer names, layers with such a layer in their ancestors will not be considered.
sanity_check: Check that LoRAs passed are correctly attached.
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
Returns:
@ -507,7 +508,7 @@ def auto_attach_loras(
f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
)
# Sanity check: if we re-run the attach, all layers should fail.
# Extra sanity check: if we re-run the attach, all layers should fail.
debug_map_2: list[tuple[str, str]] = []
failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
if debug_map_2 or len(failed_keys_2) != len(loras):

View file

@ -61,6 +61,19 @@ class SDLoraManager:
name: The name of the LoRA.
tensors: The `state_dict` of the LoRA to load.
scale: The scale to use for the LoRA.
unet_inclusions: A list of layer names, only layers with such a layer
in their ancestors will be considered when patching the UNet.
unet_exclusions: A list of layer names, layers with such a layer in
their ancestors will not be considered when patching the UNet.
If this is `None` then it defaults to `["TimestepEncoder"]`.
unet_preprocess: A map between parts of state dict keys and layer names.
This is used to attach some keys to specific parts of the UNet.
You should leave it set to `None` (it has a default value),
otherwise read the source code to understand how it works.
text_encoder_inclusions: A list of layer names, only layers with such a layer
in their ancestors will be considered when patching the text encoder.
text_encoder_exclusions: A list of layer names, layers with such a layer in
their ancestors will not be considered when patching the text encoder.
Raises:
AssertionError: If the Manager already has a LoRA with the same name.
@ -117,15 +130,22 @@ class SDLoraManager:
/,
include: list[str] | None = None,
exclude: list[str] | None = None,
debug_map: list[tuple[str, str]] | None = None,
) -> None:
"""Add multiple LoRAs to the text encoder.
"""Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments.
Args:
loras: The dictionary of LoRAs to add to the text encoder.
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
"""
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include)
auto_attach_loras(
text_encoder_loras,
self.clip_text_encoder,
exclude=exclude,
include=include,
debug_map=debug_map,
)
def add_loras_to_unet(
self,
@ -136,7 +156,7 @@ class SDLoraManager:
preprocess: dict[str, str] | None = None,
debug_map: list[tuple[str, str]] | None = None,
) -> None:
"""Add multiple LoRAs to the U-Net.
"""Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments.
Args:
loras: The dictionary of LoRAs to add to the U-Net.