mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
implement ControlLora
and ControlLoraAdapter
This commit is contained in:
parent
a54808e757
commit
41a5ce2052
|
@ -0,0 +1,401 @@
|
|||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.layers import Chain, Conv2d, Multiply, Passthrough, Residual, SiLU, UseContext
|
||||
from refiners.fluxion.layers.module import WeightedModule
|
||||
from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.unet import ResidualBlock
|
||||
|
||||
|
||||
class ConditionEncoder(Chain):
|
||||
"""Encode an image into a condition latent tensor.
|
||||
|
||||
Receives:
|
||||
(Float[Tensor, "batch in_channels width height"]): The input image.
|
||||
|
||||
Returns:
|
||||
(Float[Tensor, "batch out_channels latent_width latent_height"]): The condition latent tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 320,
|
||||
intermediate_channels: tuple[int, ...] = (16, 32, 96, 256),
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initialize the ConditionEncoder.
|
||||
|
||||
Args:
|
||||
in_channels: The number of channels of the image tensor.
|
||||
out_channels: The number of channels of the latent tensor to encode the condition into.
|
||||
intermediate_channels: The number of channels of the intermediate layers.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
Chain(
|
||||
Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=intermediate_channels[0],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
SiLU(),
|
||||
),
|
||||
*(
|
||||
Chain(
|
||||
Conv2d(
|
||||
in_channels=intermediate_channels[i],
|
||||
out_channels=intermediate_channels[i],
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
SiLU(),
|
||||
Conv2d(
|
||||
in_channels=intermediate_channels[i],
|
||||
out_channels=intermediate_channels[i + 1],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
SiLU(),
|
||||
)
|
||||
for i in range(len(intermediate_channels) - 1)
|
||||
),
|
||||
Conv2d(
|
||||
in_channels=intermediate_channels[-1],
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ZeroConvolution(Passthrough):
|
||||
"""Transform and store the ControlLora's residuals in the context of the original UNet.
|
||||
|
||||
Receives:
|
||||
(Float[Tensor, "batch in_channels width height"]): The input tensor to transform and store.
|
||||
|
||||
Returns: Updates context:
|
||||
(Tensor): Add the residual to the nth residual of the target's UNet.
|
||||
(context="unet", key="residuals")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
residual_index: int,
|
||||
scale: float = 1.0,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initialize the ZeroConvolution.
|
||||
|
||||
Args:
|
||||
in_channels: The number of channels of the input tensor.
|
||||
out_channels: The number of channels of the output tensor/residual.
|
||||
residual_index: The index of the residual to store in the target's UNet.
|
||||
scale: The scale to multiply the residuals by.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
self.scale = scale
|
||||
|
||||
super().__init__(
|
||||
Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
Multiply(scale=scale),
|
||||
ResidualAccumulator(n=residual_index),
|
||||
)
|
||||
|
||||
|
||||
class ControlLora(Passthrough):
|
||||
"""ControlLora is a Half-UNet clone of the target UNet, patched with LoRAs.
|
||||
|
||||
Like ControlNet, it injects residual tensors into the target UNet.
|
||||
See https://github.com/HighCWu/control-lora-v2 for more details.
|
||||
|
||||
Receives: Gets context:
|
||||
(Float[Tensor, "batch condition_channels width height"]): The input image.
|
||||
|
||||
Returns: Sets context:
|
||||
(list[Tensor]): The residuals to be added to the target UNet's residuals.
|
||||
(context="unet", key="residuals")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
unet: SDXLUNet,
|
||||
scale: float = 1.0,
|
||||
condition_channels: int = 3,
|
||||
) -> None:
|
||||
"""Initialize the ControlLora.
|
||||
|
||||
Args:
|
||||
name: The name of the ControlLora.
|
||||
unet: The target UNet.
|
||||
scale: The scale to multiply the residuals by.
|
||||
condition_channels: The number of channels of the input condition tensor.
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
super().__init__(
|
||||
timestep_encoder := unet.layer("TimestepEncoder", Chain).structural_copy(),
|
||||
downblocks := unet.layer("DownBlocks", Chain).structural_copy(),
|
||||
middle_block := unet.layer("MiddleBlock", Chain).structural_copy(),
|
||||
)
|
||||
|
||||
# modify the context_key of the copied TimestepEncoder to avoid conflicts
|
||||
timestep_encoder.context_key = f"timestep_embedding_control_lora_{name}"
|
||||
|
||||
# modify the context_key of each RangeAdapter2d to avoid conflicts
|
||||
for range_adapter in self.layers(RangeAdapter2d):
|
||||
range_adapter.context_key = f"timestep_embedding_control_lora_{name}"
|
||||
|
||||
# insert the ConditionEncoder in the first DownBlock
|
||||
first_downblock = downblocks.layer(0, Chain)
|
||||
out_channels = first_downblock.layer(0, Conv2d).out_channels
|
||||
first_downblock.append(
|
||||
Residual(
|
||||
UseContext(f"control_lora_{name}", f"condition"),
|
||||
ConditionEncoder(
|
||||
in_channels=condition_channels,
|
||||
out_channels=out_channels,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# replace each ResidualAccumulator by a ZeroConvolution
|
||||
for residual_accumulator in self.layers(ResidualAccumulator):
|
||||
downblock = self.ensure_find_parent(residual_accumulator)
|
||||
|
||||
first_layer = downblock[0]
|
||||
assert hasattr(first_layer, "out_channels"), f"{first_layer} has no out_channels attribute"
|
||||
|
||||
block_channels = first_layer.out_channels
|
||||
assert isinstance(block_channels, int)
|
||||
|
||||
downblock.replace(
|
||||
residual_accumulator,
|
||||
ZeroConvolution(
|
||||
scale=scale,
|
||||
residual_index=residual_accumulator.n,
|
||||
in_channels=block_channels,
|
||||
out_channels=block_channels,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
# append a ZeroConvolution to middle_block
|
||||
middle_block_channels = middle_block.layer(0, ResidualBlock).out_channels
|
||||
middle_block.append(
|
||||
ZeroConvolution(
|
||||
scale=scale,
|
||||
residual_index=len(downblocks),
|
||||
in_channels=middle_block_channels,
|
||||
out_channels=middle_block_channels,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""The scale of the injected residuals."""
|
||||
zero_convolution_module = self.ensure_find(ZeroConvolution)
|
||||
return zero_convolution_module.scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
for zero_convolution_module in self.layers(ZeroConvolution):
|
||||
zero_convolution_module.scale = value
|
||||
|
||||
|
||||
class ControlLoraAdapter(Chain, Adapter[SDXLUNet]):
|
||||
"""Adapter for ControlLora.
|
||||
|
||||
This adapter simply prepends a ControlLora model inside the target's UNet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
target: SDXLUNet,
|
||||
scale: float = 1.0,
|
||||
condition_channels: int = 3,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
with self.setup_adapter(target):
|
||||
self.name = name
|
||||
self._control_lora = [
|
||||
ControlLora(
|
||||
name=name,
|
||||
unet=target,
|
||||
scale=scale,
|
||||
condition_channels=condition_channels,
|
||||
),
|
||||
]
|
||||
|
||||
super().__init__(target)
|
||||
|
||||
if weights:
|
||||
self.load_weights(weights)
|
||||
|
||||
@property
|
||||
def control_lora(self) -> ControlLora:
|
||||
"""The ControlLora model."""
|
||||
return self._control_lora[0]
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {
|
||||
f"control_lora_{self.name}": {
|
||||
"condition": None,
|
||||
}
|
||||
}
|
||||
|
||||
def inject(self, parent: Chain | None = None) -> "ControlLoraAdapter":
|
||||
self.target.insert(index=0, module=self.control_lora)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
self.target.remove(self.control_lora)
|
||||
return super().eject()
|
||||
|
||||
def structural_copy(self) -> "ControlLoraAdapter":
|
||||
raise RuntimeError("ControlLoraAdapter cannot be copied, eject it first.")
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""The scale of the injected residuals."""
|
||||
return self.control_lora.scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
self.control_lora.scale = value
|
||||
|
||||
def set_condition(self, condition: Tensor) -> None:
|
||||
self.set_context(
|
||||
context=f"control_lora_{self.name}",
|
||||
value={"condition": condition},
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
state_dict: dict[str, Tensor],
|
||||
) -> None:
|
||||
"""Load the weights from the state_dict into the ControlLora.
|
||||
|
||||
Args:
|
||||
state_dict: The state_dict containing the weights to load.
|
||||
"""
|
||||
ControlLoraAdapter.load_lora_layers(self.name, state_dict, self.control_lora)
|
||||
ControlLoraAdapter.load_zero_convolution_layers(state_dict, self.control_lora)
|
||||
ControlLoraAdapter.load_condition_encoder(state_dict, self.control_lora)
|
||||
|
||||
@staticmethod
|
||||
def load_lora_layers(
|
||||
name: str,
|
||||
state_dict: dict[str, Tensor],
|
||||
control_lora: ControlLora,
|
||||
) -> None:
|
||||
"""Load the LoRA layers from the state_dict into the ControlLora.
|
||||
|
||||
Args:
|
||||
name: The name of the ControlLora.
|
||||
state_dict: The state_dict containing the LoRA layers to load.
|
||||
control_lora: The ControlLora to load the LoRA layers into.
|
||||
"""
|
||||
# filter the LoraAdapters from the state_dict
|
||||
lora_weights = {
|
||||
key.removeprefix("ControlLora."): value for key, value in state_dict.items() if "ControlLora" in key
|
||||
}
|
||||
lora_weights = {f"{key}.weight": value for key, value in lora_weights.items()}
|
||||
|
||||
# move the tensors to the device and dtype of the ControlLora
|
||||
lora_weights = {
|
||||
key: value.to(
|
||||
dtype=control_lora.dtype,
|
||||
device=control_lora.device,
|
||||
)
|
||||
for key, value in lora_weights.items()
|
||||
}
|
||||
|
||||
# load every LoRA layers from the filtered state_dict
|
||||
loras = Lora.from_dict(name, state_dict=lora_weights)
|
||||
|
||||
# attach the LoRA layers to the ControlLora
|
||||
adapters: list[LoraAdapter] = []
|
||||
for key, lora in loras.items():
|
||||
target = control_lora.layer(key.split("."), WeightedModule)
|
||||
assert lora.is_compatible(target)
|
||||
adapter = LoraAdapter(target, lora)
|
||||
adapters.append(adapter)
|
||||
|
||||
for adapter in adapters:
|
||||
adapter.inject(control_lora)
|
||||
|
||||
@staticmethod
|
||||
def load_zero_convolution_layers(
|
||||
state_dict: dict[str, Tensor],
|
||||
control_lora: ControlLora,
|
||||
):
|
||||
"""Load the ZeroConvolution layers from the state_dict into the ControlLora.
|
||||
|
||||
Args:
|
||||
state_dict: The state_dict containing the ZeroConvolution layers to load.
|
||||
control_lora: The ControlLora to load the ZeroConvolution layers into.
|
||||
"""
|
||||
zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
|
||||
for i, zero_convolution_layer in enumerate(zero_convolution_layers):
|
||||
zero_convolution_state_dict = {
|
||||
key.removeprefix(f"ZeroConvolution_{i+1:02d}."): value
|
||||
for key, value in state_dict.items()
|
||||
if f"ZeroConvolution_{i+1:02d}" in key
|
||||
}
|
||||
zero_convolution_layer.load_state_dict(zero_convolution_state_dict)
|
||||
|
||||
@staticmethod
|
||||
def load_condition_encoder(
|
||||
state_dict: dict[str, Tensor],
|
||||
control_lora: ControlLora,
|
||||
):
|
||||
"""Load the ConditionEncoder layers from the state_dict into the ControlLora.
|
||||
|
||||
Args:
|
||||
state_dict: The state_dict containing the ConditionEncoder layers to load.
|
||||
control_lora: The ControlLora to load the ConditionEncoder layers into.
|
||||
"""
|
||||
condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
|
||||
condition_encoder_state_dict = {
|
||||
key.removeprefix("ConditionEncoder."): value
|
||||
for key, value in state_dict.items()
|
||||
if "ConditionEncoder" in key
|
||||
}
|
||||
condition_encoder_layer.load_state_dict(condition_encoder_state_dict)
|
Loading…
Reference in a new issue