mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 22:58:45 +00:00
improve docstrings
This commit is contained in:
parent
5d784bedab
commit
975560165c
|
@ -486,10 +486,11 @@ def auto_attach_loras(
|
||||||
"""Auto-attach several LoRA layers to a Chain.
|
"""Auto-attach several LoRA layers to a Chain.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loras: A dictionary of LoRA layers associated to their respective key.
|
loras: A dictionary of LoRA layers associated to their respective key. The keys are typically
|
||||||
|
derived from the state dict and only used for `debug_map` and the return value.
|
||||||
target: The target Chain.
|
target: The target Chain.
|
||||||
include: A list of layer names, only layers with such a layer in its parents will be considered.
|
include: A list of layer names, only layers with such a layer in their ancestors will be considered.
|
||||||
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
|
exclude: A list of layer names, layers with such a layer in their ancestors will not be considered.
|
||||||
sanity_check: Check that LoRAs passed are correctly attached.
|
sanity_check: Check that LoRAs passed are correctly attached.
|
||||||
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
|
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -507,7 +508,7 @@ def auto_attach_loras(
|
||||||
f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
|
f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sanity check: if we re-run the attach, all layers should fail.
|
# Extra sanity check: if we re-run the attach, all layers should fail.
|
||||||
debug_map_2: list[tuple[str, str]] = []
|
debug_map_2: list[tuple[str, str]] = []
|
||||||
failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
|
failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
|
||||||
if debug_map_2 or len(failed_keys_2) != len(loras):
|
if debug_map_2 or len(failed_keys_2) != len(loras):
|
||||||
|
|
|
@ -61,6 +61,19 @@ class SDLoraManager:
|
||||||
name: The name of the LoRA.
|
name: The name of the LoRA.
|
||||||
tensors: The `state_dict` of the LoRA to load.
|
tensors: The `state_dict` of the LoRA to load.
|
||||||
scale: The scale to use for the LoRA.
|
scale: The scale to use for the LoRA.
|
||||||
|
unet_inclusions: A list of layer names, only layers with such a layer
|
||||||
|
in their ancestors will be considered when patching the UNet.
|
||||||
|
unet_exclusions: A list of layer names, layers with such a layer in
|
||||||
|
their ancestors will not be considered when patching the UNet.
|
||||||
|
If this is `None` then it defaults to `["TimestepEncoder"]`.
|
||||||
|
unet_preprocess: A map between parts of state dict keys and layer names.
|
||||||
|
This is used to attach some keys to specific parts of the UNet.
|
||||||
|
You should leave it set to `None` (it has a default value),
|
||||||
|
otherwise read the source code to understand how it works.
|
||||||
|
text_encoder_inclusions: A list of layer names, only layers with such a layer
|
||||||
|
in their ancestors will be considered when patching the text encoder.
|
||||||
|
text_encoder_exclusions: A list of layer names, layers with such a layer in
|
||||||
|
their ancestors will not be considered when patching the text encoder.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If the Manager already has a LoRA with the same name.
|
AssertionError: If the Manager already has a LoRA with the same name.
|
||||||
|
@ -117,15 +130,22 @@ class SDLoraManager:
|
||||||
/,
|
/,
|
||||||
include: list[str] | None = None,
|
include: list[str] | None = None,
|
||||||
exclude: list[str] | None = None,
|
exclude: list[str] | None = None,
|
||||||
|
debug_map: list[tuple[str, str]] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add multiple LoRAs to the text encoder.
|
"""Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loras: The dictionary of LoRAs to add to the text encoder.
|
loras: The dictionary of LoRAs to add to the text encoder.
|
||||||
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
||||||
"""
|
"""
|
||||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||||
auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include)
|
auto_attach_loras(
|
||||||
|
text_encoder_loras,
|
||||||
|
self.clip_text_encoder,
|
||||||
|
exclude=exclude,
|
||||||
|
include=include,
|
||||||
|
debug_map=debug_map,
|
||||||
|
)
|
||||||
|
|
||||||
def add_loras_to_unet(
|
def add_loras_to_unet(
|
||||||
self,
|
self,
|
||||||
|
@ -136,7 +156,7 @@ class SDLoraManager:
|
||||||
preprocess: dict[str, str] | None = None,
|
preprocess: dict[str, str] | None = None,
|
||||||
debug_map: list[tuple[str, str]] | None = None,
|
debug_map: list[tuple[str, str]] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add multiple LoRAs to the U-Net.
|
"""Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loras: The dictionary of LoRAs to add to the U-Net.
|
loras: The dictionary of LoRAs to add to the U-Net.
|
||||||
|
|
Loading…
Reference in a new issue