export ControlLora and ControlLoraAdapter in refiners.foundationals.latent_diffusion.stable_diffusion_xl

This commit is contained in:
Laurent 2024-02-15 09:30:21 +00:00 committed by Laureηt
parent 383793b534
commit 684303230d
2 changed files with 13 additions and 9 deletions

View file

@ -1,3 +1,4 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
@ -11,4 +12,6 @@ __all__ = [
"SDXLAutoencoder", "SDXLAutoencoder",
"SDXLIPAdapter", "SDXLIPAdapter",
"SDXLT2IAdapter", "SDXLT2IAdapter",
"ControlLora",
"ControlLoraAdapter",
] ]

View file

@ -133,10 +133,11 @@ class ZeroConvolution(Passthrough):
class ControlLora(Passthrough): class ControlLora(Passthrough):
"""ControlLora is a Half-UNet clone of the target UNet, patched with LoRAs. """ControlLora is a Half-UNet clone of the target UNet,
patched with various `LoRA` layers, `ZeroConvolution` layers, and a `ConditionEncoder`.
Like ControlNet, it injects residual tensors into the target UNet. Like ControlNet, it injects residual tensors into the target UNet.
See https://github.com/HighCWu/control-lora-v2 for more details. See <https://github.com/HighCWu/control-lora-v2> for more details.
Receives: Gets context: Receives: Gets context:
(Float[Tensor, "batch condition_channels width height"]): The input image. (Float[Tensor, "batch condition_channels width height"]): The input image.
@ -228,7 +229,7 @@ class ControlLora(Passthrough):
@property @property
def scale(self) -> float: def scale(self) -> float:
"""The scale of the injected residuals.""" """The scale of the residuals stored in the context."""
zero_convolution_module = self.ensure_find(ZeroConvolution) zero_convolution_module = self.ensure_find(ZeroConvolution)
return zero_convolution_module.scale return zero_convolution_module.scale
@ -239,9 +240,9 @@ class ControlLora(Passthrough):
class ControlLoraAdapter(Chain, Adapter[SDXLUNet]): class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
"""Adapter for ControlLora. """Adapter for [`ControlLora`][refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLora].
This adapter simply prepends a ControlLora model inside the target's UNet. This adapter simply prepends a `ControlLora` model inside the target `SDXLUNet`.
""" """
def __init__( def __init__(
@ -310,7 +311,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
self, self,
state_dict: dict[str, Tensor], state_dict: dict[str, Tensor],
) -> None: ) -> None:
"""Load the weights from the state_dict into the ControlLora. """Load the weights from the state_dict into the `ControlLora`.
Args: Args:
state_dict: The state_dict containing the weights to load. state_dict: The state_dict containing the weights to load.
@ -325,7 +326,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
state_dict: dict[str, Tensor], state_dict: dict[str, Tensor],
control_lora: ControlLora, control_lora: ControlLora,
) -> None: ) -> None:
"""Load the LoRA layers from the state_dict into the ControlLora. """Load the [`LoRA`][refiners.fluxion.adapters.lora.Lora] layers from the state_dict into the `ControlLora`.
Args: Args:
name: The name of the ControlLora. name: The name of the ControlLora.
@ -366,7 +367,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
state_dict: dict[str, Tensor], state_dict: dict[str, Tensor],
control_lora: ControlLora, control_lora: ControlLora,
): ):
"""Load the ZeroConvolution layers from the state_dict into the ControlLora. """Load the `ZeroConvolution` layers from the state_dict into the `ControlLora`.
Args: Args:
state_dict: The state_dict containing the ZeroConvolution layers to load. state_dict: The state_dict containing the ZeroConvolution layers to load.
@ -386,7 +387,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
state_dict: dict[str, Tensor], state_dict: dict[str, Tensor],
control_lora: ControlLora, control_lora: ControlLora,
): ):
"""Load the ConditionEncoder layers from the state_dict into the ControlLora. """Load the `ConditionEncoder`'s layers from the state_dict into the `ControlLora`.
Args: Args:
state_dict: The state_dict containing the ConditionEncoder layers to load. state_dict: The state_dict containing the ConditionEncoder layers to load.