mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
export ControlLora
and ControlLoraAdapter
in refiners.foundationals.latent_diffusion.stable_diffusion_xl
This commit is contained in:
parent
383793b534
commit
684303230d
|
@ -1,3 +1,4 @@
|
|||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
|
||||
|
@ -11,4 +12,6 @@ __all__ = [
|
|||
"SDXLAutoencoder",
|
||||
"SDXLIPAdapter",
|
||||
"SDXLT2IAdapter",
|
||||
"ControlLora",
|
||||
"ControlLoraAdapter",
|
||||
]
|
||||
|
|
|
@ -133,10 +133,11 @@ class ZeroConvolution(Passthrough):
|
|||
|
||||
|
||||
class ControlLora(Passthrough):
|
||||
"""ControlLora is a Half-UNet clone of the target UNet, patched with LoRAs.
|
||||
"""ControlLora is a Half-UNet clone of the target UNet,
|
||||
patched with various `LoRA` layers, `ZeroConvolution` layers, and a `ConditionEncoder`.
|
||||
|
||||
Like ControlNet, it injects residual tensors into the target UNet.
|
||||
See https://github.com/HighCWu/control-lora-v2 for more details.
|
||||
See <https://github.com/HighCWu/control-lora-v2> for more details.
|
||||
|
||||
Receives: Gets context:
|
||||
(Float[Tensor, "batch condition_channels width height"]): The input image.
|
||||
|
@ -228,7 +229,7 @@ class ControlLora(Passthrough):
|
|||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""The scale of the injected residuals."""
|
||||
"""The scale of the residuals stored in the context."""
|
||||
zero_convolution_module = self.ensure_find(ZeroConvolution)
|
||||
return zero_convolution_module.scale
|
||||
|
||||
|
@ -239,9 +240,9 @@ class ControlLora(Passthrough):
|
|||
|
||||
|
||||
class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||
"""Adapter for ControlLora.
|
||||
"""Adapter for [`ControlLora`][refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLora].
|
||||
|
||||
This adapter simply prepends a ControlLora model inside the target's UNet.
|
||||
This adapter simply prepends a `ControlLora` model inside the target `SDXLUNet`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -310,7 +311,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
|||
self,
|
||||
state_dict: dict[str, Tensor],
|
||||
) -> None:
|
||||
"""Load the weights from the state_dict into the ControlLora.
|
||||
"""Load the weights from the state_dict into the `ControlLora`.
|
||||
|
||||
Args:
|
||||
state_dict: The state_dict containing the weights to load.
|
||||
|
@ -325,7 +326,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
|||
state_dict: dict[str, Tensor],
|
||||
control_lora: ControlLora,
|
||||
) -> None:
|
||||
"""Load the LoRA layers from the state_dict into the ControlLora.
|
||||
"""Load the [`LoRA`][refiners.fluxion.adapters.lora.Lora] layers from the state_dict into the `ControlLora`.
|
||||
|
||||
Args:
|
||||
name: The name of the ControlLora.
|
||||
|
@ -366,7 +367,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
|||
state_dict: dict[str, Tensor],
|
||||
control_lora: ControlLora,
|
||||
):
|
||||
"""Load the ZeroConvolution layers from the state_dict into the ControlLora.
|
||||
"""Load the `ZeroConvolution` layers from the state_dict into the `ControlLora`.
|
||||
|
||||
Args:
|
||||
state_dict: The state_dict containing the ZeroConvolution layers to load.
|
||||
|
@ -386,7 +387,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
|||
state_dict: dict[str, Tensor],
|
||||
control_lora: ControlLora,
|
||||
):
|
||||
"""Load the ConditionEncoder layers from the state_dict into the ControlLora.
|
||||
"""Load the `ConditionEncoder`'s layers from the state_dict into the `ControlLora`.
|
||||
|
||||
Args:
|
||||
state_dict: The state_dict containing the ConditionEncoder layers to load.
|
||||
|
|
Loading…
Reference in a new issue