mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
implement ControlLora
and ControlLoraAdapter
This commit is contained in:
parent
a54808e757
commit
41a5ce2052
|
@ -0,0 +1,401 @@
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
from refiners.fluxion.layers import Chain, Conv2d, Multiply, Passthrough, Residual, SiLU, UseContext
|
||||||
|
from refiners.fluxion.layers.module import WeightedModule
|
||||||
|
from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import ResidualBlock
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionEncoder(Chain):
|
||||||
|
"""Encode an image into a condition latent tensor.
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Float[Tensor, "batch in_channels width height"]): The input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch out_channels latent_width latent_height"]): The condition latent tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 3,
|
||||||
|
out_channels: int = 320,
|
||||||
|
intermediate_channels: tuple[int, ...] = (16, 32, 96, 256),
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the ConditionEncoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: The number of channels of the image tensor.
|
||||||
|
out_channels: The number of channels of the latent tensor to encode the condition into.
|
||||||
|
intermediate_channels: The number of channels of the intermediate layers.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
Chain(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=intermediate_channels[0],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SiLU(),
|
||||||
|
),
|
||||||
|
*(
|
||||||
|
Chain(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=intermediate_channels[i],
|
||||||
|
out_channels=intermediate_channels[i],
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SiLU(),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=intermediate_channels[i],
|
||||||
|
out_channels=intermediate_channels[i + 1],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SiLU(),
|
||||||
|
)
|
||||||
|
for i in range(len(intermediate_channels) - 1)
|
||||||
|
),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=intermediate_channels[-1],
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroConvolution(Passthrough):
|
||||||
|
"""Transform and store the ControlLora's residuals in the context of the original UNet.
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Float[Tensor, "batch in_channels width height"]): The input tensor to transform and store.
|
||||||
|
|
||||||
|
Returns: Updates context:
|
||||||
|
(Tensor): Add the residual to the nth residual of the target's UNet.
|
||||||
|
(context="unet", key="residuals")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
residual_index: int,
|
||||||
|
scale: float = 1.0,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the ZeroConvolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: The number of channels of the input tensor.
|
||||||
|
out_channels: The number of channels of the output tensor/residual.
|
||||||
|
residual_index: The index of the residual to store in the target's UNet.
|
||||||
|
scale: The scale to multiply the residuals by.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Multiply(scale=scale),
|
||||||
|
ResidualAccumulator(n=residual_index),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ControlLora(Passthrough):
|
||||||
|
"""ControlLora is a Half-UNet clone of the target UNet, patched with LoRAs.
|
||||||
|
|
||||||
|
Like ControlNet, it injects residual tensors into the target UNet.
|
||||||
|
See https://github.com/HighCWu/control-lora-v2 for more details.
|
||||||
|
|
||||||
|
Receives: Gets context:
|
||||||
|
(Float[Tensor, "batch condition_channels width height"]): The input image.
|
||||||
|
|
||||||
|
Returns: Sets context:
|
||||||
|
(list[Tensor]): The residuals to be added to the target UNet's residuals.
|
||||||
|
(context="unet", key="residuals")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
unet: SDXLUNet,
|
||||||
|
scale: float = 1.0,
|
||||||
|
condition_channels: int = 3,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the ControlLora.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the ControlLora.
|
||||||
|
unet: The target UNet.
|
||||||
|
scale: The scale to multiply the residuals by.
|
||||||
|
condition_channels: The number of channels of the input condition tensor.
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
timestep_encoder := unet.layer("TimestepEncoder", Chain).structural_copy(),
|
||||||
|
downblocks := unet.layer("DownBlocks", Chain).structural_copy(),
|
||||||
|
middle_block := unet.layer("MiddleBlock", Chain).structural_copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# modify the context_key of the copied TimestepEncoder to avoid conflicts
|
||||||
|
timestep_encoder.context_key = f"timestep_embedding_control_lora_{name}"
|
||||||
|
|
||||||
|
# modify the context_key of each RangeAdapter2d to avoid conflicts
|
||||||
|
for range_adapter in self.layers(RangeAdapter2d):
|
||||||
|
range_adapter.context_key = f"timestep_embedding_control_lora_{name}"
|
||||||
|
|
||||||
|
# insert the ConditionEncoder in the first DownBlock
|
||||||
|
first_downblock = downblocks.layer(0, Chain)
|
||||||
|
out_channels = first_downblock.layer(0, Conv2d).out_channels
|
||||||
|
first_downblock.append(
|
||||||
|
Residual(
|
||||||
|
UseContext(f"control_lora_{name}", f"condition"),
|
||||||
|
ConditionEncoder(
|
||||||
|
in_channels=condition_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
device=unet.device,
|
||||||
|
dtype=unet.dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# replace each ResidualAccumulator by a ZeroConvolution
|
||||||
|
for residual_accumulator in self.layers(ResidualAccumulator):
|
||||||
|
downblock = self.ensure_find_parent(residual_accumulator)
|
||||||
|
|
||||||
|
first_layer = downblock[0]
|
||||||
|
assert hasattr(first_layer, "out_channels"), f"{first_layer} has no out_channels attribute"
|
||||||
|
|
||||||
|
block_channels = first_layer.out_channels
|
||||||
|
assert isinstance(block_channels, int)
|
||||||
|
|
||||||
|
downblock.replace(
|
||||||
|
residual_accumulator,
|
||||||
|
ZeroConvolution(
|
||||||
|
scale=scale,
|
||||||
|
residual_index=residual_accumulator.n,
|
||||||
|
in_channels=block_channels,
|
||||||
|
out_channels=block_channels,
|
||||||
|
device=unet.device,
|
||||||
|
dtype=unet.dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# append a ZeroConvolution to middle_block
|
||||||
|
middle_block_channels = middle_block.layer(0, ResidualBlock).out_channels
|
||||||
|
middle_block.append(
|
||||||
|
ZeroConvolution(
|
||||||
|
scale=scale,
|
||||||
|
residual_index=len(downblocks),
|
||||||
|
in_channels=middle_block_channels,
|
||||||
|
out_channels=middle_block_channels,
|
||||||
|
device=unet.device,
|
||||||
|
dtype=unet.dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self) -> float:
|
||||||
|
"""The scale of the injected residuals."""
|
||||||
|
zero_convolution_module = self.ensure_find(ZeroConvolution)
|
||||||
|
return zero_convolution_module.scale
|
||||||
|
|
||||||
|
@scale.setter
|
||||||
|
def scale(self, value: float) -> None:
|
||||||
|
for zero_convolution_module in self.layers(ZeroConvolution):
|
||||||
|
zero_convolution_module.scale = value
|
||||||
|
|
||||||
|
|
||||||
|
class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||||
|
"""Adapter for ControlLora.
|
||||||
|
|
||||||
|
This adapter simply prepends a ControlLora model inside the target's UNet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
target: SDXLUNet,
|
||||||
|
scale: float = 1.0,
|
||||||
|
condition_channels: int = 3,
|
||||||
|
weights: dict[str, Tensor] | None = None,
|
||||||
|
) -> None:
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
self.name = name
|
||||||
|
self._control_lora = [
|
||||||
|
ControlLora(
|
||||||
|
name=name,
|
||||||
|
unet=target,
|
||||||
|
scale=scale,
|
||||||
|
condition_channels=condition_channels,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
super().__init__(target)
|
||||||
|
|
||||||
|
if weights:
|
||||||
|
self.load_weights(weights)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def control_lora(self) -> ControlLora:
|
||||||
|
"""The ControlLora model."""
|
||||||
|
return self._control_lora[0]
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {
|
||||||
|
f"control_lora_{self.name}": {
|
||||||
|
"condition": None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def inject(self, parent: Chain | None = None) -> "ControlLoraAdapter":
|
||||||
|
self.target.insert(index=0, module=self.control_lora)
|
||||||
|
return super().inject(parent)
|
||||||
|
|
||||||
|
def eject(self) -> None:
|
||||||
|
self.target.remove(self.control_lora)
|
||||||
|
return super().eject()
|
||||||
|
|
||||||
|
def structural_copy(self) -> "ControlLoraAdapter":
|
||||||
|
raise RuntimeError("ControlLoraAdapter cannot be copied, eject it first.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self) -> float:
|
||||||
|
"""The scale of the injected residuals."""
|
||||||
|
return self.control_lora.scale
|
||||||
|
|
||||||
|
@scale.setter
|
||||||
|
def scale(self, value: float) -> None:
|
||||||
|
self.control_lora.scale = value
|
||||||
|
|
||||||
|
def set_condition(self, condition: Tensor) -> None:
|
||||||
|
self.set_context(
|
||||||
|
context=f"control_lora_{self.name}",
|
||||||
|
value={"condition": condition},
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_weights(
|
||||||
|
self,
|
||||||
|
state_dict: dict[str, Tensor],
|
||||||
|
) -> None:
|
||||||
|
"""Load the weights from the state_dict into the ControlLora.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict: The state_dict containing the weights to load.
|
||||||
|
"""
|
||||||
|
ControlLoraAdapter.load_lora_layers(self.name, state_dict, self.control_lora)
|
||||||
|
ControlLoraAdapter.load_zero_convolution_layers(state_dict, self.control_lora)
|
||||||
|
ControlLoraAdapter.load_condition_encoder(state_dict, self.control_lora)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_lora_layers(
|
||||||
|
name: str,
|
||||||
|
state_dict: dict[str, Tensor],
|
||||||
|
control_lora: ControlLora,
|
||||||
|
) -> None:
|
||||||
|
"""Load the LoRA layers from the state_dict into the ControlLora.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the ControlLora.
|
||||||
|
state_dict: The state_dict containing the LoRA layers to load.
|
||||||
|
control_lora: The ControlLora to load the LoRA layers into.
|
||||||
|
"""
|
||||||
|
# filter the LoraAdapters from the state_dict
|
||||||
|
lora_weights = {
|
||||||
|
key.removeprefix("ControlLora."): value for key, value in state_dict.items() if "ControlLora" in key
|
||||||
|
}
|
||||||
|
lora_weights = {f"{key}.weight": value for key, value in lora_weights.items()}
|
||||||
|
|
||||||
|
# move the tensors to the device and dtype of the ControlLora
|
||||||
|
lora_weights = {
|
||||||
|
key: value.to(
|
||||||
|
dtype=control_lora.dtype,
|
||||||
|
device=control_lora.device,
|
||||||
|
)
|
||||||
|
for key, value in lora_weights.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# load every LoRA layers from the filtered state_dict
|
||||||
|
loras = Lora.from_dict(name, state_dict=lora_weights)
|
||||||
|
|
||||||
|
# attach the LoRA layers to the ControlLora
|
||||||
|
adapters: list[LoraAdapter] = []
|
||||||
|
for key, lora in loras.items():
|
||||||
|
target = control_lora.layer(key.split("."), WeightedModule)
|
||||||
|
assert lora.is_compatible(target)
|
||||||
|
adapter = LoraAdapter(target, lora)
|
||||||
|
adapters.append(adapter)
|
||||||
|
|
||||||
|
for adapter in adapters:
|
||||||
|
adapter.inject(control_lora)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_zero_convolution_layers(
|
||||||
|
state_dict: dict[str, Tensor],
|
||||||
|
control_lora: ControlLora,
|
||||||
|
):
|
||||||
|
"""Load the ZeroConvolution layers from the state_dict into the ControlLora.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict: The state_dict containing the ZeroConvolution layers to load.
|
||||||
|
control_lora: The ControlLora to load the ZeroConvolution layers into.
|
||||||
|
"""
|
||||||
|
zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
|
||||||
|
for i, zero_convolution_layer in enumerate(zero_convolution_layers):
|
||||||
|
zero_convolution_state_dict = {
|
||||||
|
key.removeprefix(f"ZeroConvolution_{i+1:02d}."): value
|
||||||
|
for key, value in state_dict.items()
|
||||||
|
if f"ZeroConvolution_{i+1:02d}" in key
|
||||||
|
}
|
||||||
|
zero_convolution_layer.load_state_dict(zero_convolution_state_dict)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_condition_encoder(
|
||||||
|
state_dict: dict[str, Tensor],
|
||||||
|
control_lora: ControlLora,
|
||||||
|
):
|
||||||
|
"""Load the ConditionEncoder layers from the state_dict into the ControlLora.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict: The state_dict containing the ConditionEncoder layers to load.
|
||||||
|
control_lora: The ControlLora to load the ConditionEncoder layers into.
|
||||||
|
"""
|
||||||
|
condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
|
||||||
|
condition_encoder_state_dict = {
|
||||||
|
key.removeprefix("ConditionEncoder."): value
|
||||||
|
for key, value in state_dict.items()
|
||||||
|
if "ConditionEncoder" in key
|
||||||
|
}
|
||||||
|
condition_encoder_layer.load_state_dict(condition_encoder_state_dict)
|
Loading…
Reference in a new issue