mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
export ControlLora
and ControlLoraAdapter
in refiners.foundationals.latent_diffusion.stable_diffusion_xl
This commit is contained in:
parent
383793b534
commit
684303230d
|
@ -1,3 +1,4 @@
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
|
||||||
|
@ -11,4 +12,6 @@ __all__ = [
|
||||||
"SDXLAutoencoder",
|
"SDXLAutoencoder",
|
||||||
"SDXLIPAdapter",
|
"SDXLIPAdapter",
|
||||||
"SDXLT2IAdapter",
|
"SDXLT2IAdapter",
|
||||||
|
"ControlLora",
|
||||||
|
"ControlLoraAdapter",
|
||||||
]
|
]
|
||||||
|
|
|
@ -133,10 +133,11 @@ class ZeroConvolution(Passthrough):
|
||||||
|
|
||||||
|
|
||||||
class ControlLora(Passthrough):
|
class ControlLora(Passthrough):
|
||||||
"""ControlLora is a Half-UNet clone of the target UNet, patched with LoRAs.
|
"""ControlLora is a Half-UNet clone of the target UNet,
|
||||||
|
patched with various `LoRA` layers, `ZeroConvolution` layers, and a `ConditionEncoder`.
|
||||||
|
|
||||||
Like ControlNet, it injects residual tensors into the target UNet.
|
Like ControlNet, it injects residual tensors into the target UNet.
|
||||||
See https://github.com/HighCWu/control-lora-v2 for more details.
|
See <https://github.com/HighCWu/control-lora-v2> for more details.
|
||||||
|
|
||||||
Receives: Gets context:
|
Receives: Gets context:
|
||||||
(Float[Tensor, "batch condition_channels width height"]): The input image.
|
(Float[Tensor, "batch condition_channels width height"]): The input image.
|
||||||
|
@ -228,7 +229,7 @@ class ControlLora(Passthrough):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scale(self) -> float:
|
||||||
"""The scale of the injected residuals."""
|
"""The scale of the residuals stored in the context."""
|
||||||
zero_convolution_module = self.ensure_find(ZeroConvolution)
|
zero_convolution_module = self.ensure_find(ZeroConvolution)
|
||||||
return zero_convolution_module.scale
|
return zero_convolution_module.scale
|
||||||
|
|
||||||
|
@ -239,9 +240,9 @@ class ControlLora(Passthrough):
|
||||||
|
|
||||||
|
|
||||||
class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||||
"""Adapter for ControlLora.
|
"""Adapter for [`ControlLora`][refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLora].
|
||||||
|
|
||||||
This adapter simply prepends a ControlLora model inside the target's UNet.
|
This adapter simply prepends a `ControlLora` model inside the target `SDXLUNet`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -310,7 +311,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||||
self,
|
self,
|
||||||
state_dict: dict[str, Tensor],
|
state_dict: dict[str, Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load the weights from the state_dict into the ControlLora.
|
"""Load the weights from the state_dict into the `ControlLora`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_dict: The state_dict containing the weights to load.
|
state_dict: The state_dict containing the weights to load.
|
||||||
|
@ -325,7 +326,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||||
state_dict: dict[str, Tensor],
|
state_dict: dict[str, Tensor],
|
||||||
control_lora: ControlLora,
|
control_lora: ControlLora,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load the LoRA layers from the state_dict into the ControlLora.
|
"""Load the [`LoRA`][refiners.fluxion.adapters.lora.Lora] layers from the state_dict into the `ControlLora`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the ControlLora.
|
name: The name of the ControlLora.
|
||||||
|
@ -366,7 +367,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||||
state_dict: dict[str, Tensor],
|
state_dict: dict[str, Tensor],
|
||||||
control_lora: ControlLora,
|
control_lora: ControlLora,
|
||||||
):
|
):
|
||||||
"""Load the ZeroConvolution layers from the state_dict into the ControlLora.
|
"""Load the `ZeroConvolution` layers from the state_dict into the `ControlLora`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_dict: The state_dict containing the ZeroConvolution layers to load.
|
state_dict: The state_dict containing the ZeroConvolution layers to load.
|
||||||
|
@ -386,7 +387,7 @@ class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||||
state_dict: dict[str, Tensor],
|
state_dict: dict[str, Tensor],
|
||||||
control_lora: ControlLora,
|
control_lora: ControlLora,
|
||||||
):
|
):
|
||||||
"""Load the ConditionEncoder layers from the state_dict into the ControlLora.
|
"""Load the `ConditionEncoder`'s layers from the state_dict into the `ControlLora`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_dict: The state_dict containing the ConditionEncoder layers to load.
|
state_dict: The state_dict containing the ConditionEncoder layers to load.
|
||||||
|
|
Loading…
Reference in a new issue