mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 21:58:47 +00:00)
write ControlLora weight conversion script
This commit is contained in:
parent 41a5ce2052
commit 5fee723cd1

349  scripts/conversion/convert_fooocus_control_lora.py  (new file)

@@ -0,0 +1,349 @@
import argparse
import logging
from logging import info
from pathlib import Path

from huggingface_hub import hf_hub_download  # type: ignore
from torch import Tensor
from torch.nn import Parameter as TorchParameter

from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras
from refiners.fluxion.layers import Conv2d
from refiners.fluxion.layers.linear import Linear
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import (
    ConditionEncoder,
    ControlLora,
    ControlLoraAdapter,
    ZeroConvolution,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL


def sort_keys(key: str, /) -> tuple[str, int]:
    """Compute the score of a key, relative to its suffix.

    When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level".

    Args:
        key: The key to sort.

    Returns:
        The padded suffix of the key.
        The score of the key's suffix.
    """
    if "time_embed" in key:  # HACK: will place the "time_embed" layers at the very start of the list
        return ("", -2)

    if "label_emb" in key:  # HACK: will place the "label_emb" layers right after "time_embed"
        return ("", -1)

    if "proj_out" in key:  # HACK: will place the "proj_out" layers at the end of each "transformer_blocks"
        return (key.removesuffix("proj_out") + "transformer_blocks.99.ff.net.2", 10)

    return SDLoraManager.sort_keys(key)
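

# Illustrative note (not part of the original script): because "time_embed" and
# "label_emb" keys map to ("", -2) and ("", -1), sorted(keys, key=sort_keys) is
# expected to place them before everything else, e.g. (hypothetical key names):
#   sorted(["input_blocks.1.1.to_q", "label_emb.0.0", "time_embed.0"], key=sort_keys)
#   # expected -> ["time_embed.0", "label_emb.0.0", "input_blocks.1.1.to_q"]
# "proj_out" keys are remapped so they sort after the "ff" entries of their own
# transformer block; all other keys keep the order defined by SDLoraManager.sort_keys.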


def load_lora_layers(
    name: str,
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> dict[str, Lora[Linear | Conv2d]]:
    """Load the LoRA layers from the state_dict into the ControlLora.

    Args:
        name: The name of the LoRA.
        state_dict: The state_dict of the LoRA.
        control_lora: The ControlLora to load the LoRA layers into.

    Returns:
        The loaded LoRA layers, indexed by key.
    """
    # filter from the state_dict the layers that will be used for the LoRA layers
    lora_weights = {f"{key}.weight": value for key, value in state_dict.items() if ".up" in key or ".down" in key}

    # move the tensors to the device and dtype of the ControlLora
    lora_weights = {
        key: value.to(
            dtype=control_lora.dtype,
            device=control_lora.device,
        )
        for key, value in lora_weights.items()
    }

    # load every LoRA layer from the filtered state_dict
    lora_layers = Lora.from_dict(name, state_dict=lora_weights)

    # sort all the LoRA keys using the `sort_keys` method
    lora_layers = {
        key: lora_layers[key]
        for key in sorted(
            lora_layers.keys(),
            key=sort_keys,
        )
    }

    # auto-attach the LoRA layers to the U-Net
    failed_keys = auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
    assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers."

    # eject all the LoRA adapters from the U-Net,
    # because we need each target path as if the adapter wasn't injected
    for lora_layer in lora_layers.values():
        lora_adapter = lora_layer.parent
        assert isinstance(lora_adapter, LoraAdapter)
        lora_adapter.eject()

    return lora_layers


def load_condition_encoder(
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None:
    """Load the ConditionEncoder's Conv2d layers from the state_dict into the ControlLora.

    Args:
        state_dict: The state_dict of the ConditionEncoder.
        control_lora: The ControlLora to load the ConditionEncoder's Conv2d layers into.
    """
    # filter from the state_dict the layers that will be used for the ConditionEncoder
    condition_encoder_tensors = {key: value for key, value in state_dict.items() if "input_hint_block" in key}

    # move the tensors to the device and dtype of the ControlLora
    condition_encoder_tensors = {
        key: value.to(
            dtype=control_lora.dtype,
            device=control_lora.device,
        )
        for key, value in condition_encoder_tensors.items()
    }

    # find the ConditionEncoder's Conv2d layers
    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
    condition_encoder_conv2ds = list(condition_encoder_layer.layers(Conv2d))

    # replace the Conv2d layers' weights and biases with the ones from the state_dict
    for i, layer in enumerate(condition_encoder_conv2ds):
        layer.weight = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.weight"])
        layer.bias = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.bias"])
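

# Illustrative note (not part of the original script): the ConditionEncoder weights are
# stored under even indices ("input_hint_block.0", "input_hint_block.2", ...), presumably
# because the source module interleaves each Conv2d with an activation, hence the `i*2`
# indexing above when matching them to the i-th Conv2d found in refiners' ConditionEncoder.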


def load_zero_convolutions(
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None:
    """Load the ZeroConvolution layers' Conv2d weights from the state_dict into the ControlLora.

    Args:
        state_dict: The state_dict of the ZeroConvolution layers.
        control_lora: The ControlLora to load the ZeroConvolution layers' Conv2d weights into.
    """
    # filter from the state_dict the layers that will be used for the ZeroConvolution layers
    zero_convolution_tensors = {key: value for key, value in state_dict.items() if "zero_convs" in key}
    n = len(zero_convolution_tensors) // 2
    zero_convolution_tensors[f"zero_convs.{n}.0.weight"] = state_dict["middle_block_out.0.weight"]
    zero_convolution_tensors[f"zero_convs.{n}.0.bias"] = state_dict["middle_block_out.0.bias"]

    # move the tensors to the device and dtype of the ControlLora
    zero_convolution_tensors = {
        key: value.to(
            dtype=control_lora.dtype,
            device=control_lora.device,
        )
        for key, value in zero_convolution_tensors.items()
    }

    # find the ZeroConvolution's Conv2d layers
    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
    zero_convolution_conv2ds = [layer.ensure_find(Conv2d) for layer in zero_convolution_layers]

    # replace the Conv2d layers' weights and biases with the ones from the state_dict
    for i, layer in enumerate(zero_convolution_conv2ds):
        layer.weight = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.weight"])
        layer.bias = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.bias"])
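

# Illustrative note (not part of the original script): "middle_block_out.0" is re-keyed
# above as the last "zero_convs" entry, so the loop assigns it to the last ZeroConvolution
# found in the ControlLora, which presumably corresponds to the middle block in refiners'
# layer order.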


def simplify_key(key: str, prefix: str, index: int | None = None) -> str:
    """Simplify a key by stripping everything to the left of the prefix.

    Also optionally add a zero-padded index to the prefix.

    Example:
        >>> simplify_key("foo.bar.ControlLora.something", "ControlLora", 1)
        "ControlLora_01.something"

        >>> simplify_key("foo.bar.ControlLora.DownBlocks.something", "ControlLora")
        "ControlLora.DownBlocks.something"

    Args:
        key: The key to simplify.
        prefix: The prefix to remove.
        index: The index to add.
    """
    _, right = key.split(prefix, maxsplit=1)
    if index:
        return f"{prefix}_{index:02d}{right}"
    else:
        return f"{prefix}{right}"


def convert_lora_layers(
    lora_layers: dict[str, Lora[Linear | Conv2d]],
    control_lora: ControlLora,
    refiners_state_dict: dict[str, Tensor],
) -> None:
    """Convert the LoRA layers to the refiners format.

    Args:
        lora_layers: The LoRA layers to convert.
        control_lora: The ControlLora to convert the LoRA layers from.
        refiners_state_dict: The refiners state dict to update with the converted LoRA layers.
    """
    for lora_layer in lora_layers.values():
        # get the adapter associated with the LoRA layer
        lora_adapter = lora_layer.parent
        assert isinstance(lora_adapter, LoraAdapter)

        # get the path of the adapter's target in the ControlLora
        target = lora_adapter.target
        path = target.get_path(parent=control_lora.ensure_find_parent(target))

        state_dict = {
            f"{path}.down": lora_layer.down.weight,
            f"{path}.up": lora_layer.up.weight,
        }
        state_dict = {simplify_key(key, "ControlLora"): param for key, param in state_dict.items()}
        refiners_state_dict.update(state_dict)


def convert_zero_convolutions(
    control_lora: ControlLora,
    refiners_state_dict: dict[str, Tensor],
) -> None:
    """Convert the ZeroConvolution layers to the refiners format.

    Args:
        control_lora: The ControlLora to convert the ZeroConvolution layers from.
        refiners_state_dict: The refiners state dict to update with the converted ZeroConvolution layers.
    """
    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
    for i, zero_convolution_layer in enumerate(zero_convolution_layers):
        state_dict = zero_convolution_layer.state_dict()
        path = zero_convolution_layer.get_path()
        state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
        state_dict = {simplify_key(key, "ZeroConvolution", i + 1): param for key, param in state_dict.items()}
        refiners_state_dict.update(state_dict)


def convert_condition_encoder(
    control_lora: ControlLora,
    refiners_state_dict: dict[str, Tensor],
) -> None:
    """Convert the ConditionEncoder to the refiners format.

    Args:
        control_lora: The ControlLora to convert the ConditionEncoder from.
        refiners_state_dict: The refiners state dict to update with the converted ConditionEncoder.
    """
    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
    path = condition_encoder_layer.get_path()
    state_dict = condition_encoder_layer.state_dict()
    state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
    state_dict = {simplify_key(key, "ConditionEncoder"): param for key, param in state_dict.items()}
    refiners_state_dict.update(state_dict)


def convert(
    name: str,
    state_dict_path: Path,
    output_path: Path,
) -> None:
    sdxl = StableDiffusion_XL()
    info("Stable Diffusion XL model initialized")

    fooocus_state_dict = load_from_safetensors(state_dict_path)
    info(f"Fooocus weights loaded from: {state_dict_path}")

    control_lora_adapter = ControlLoraAdapter(target=sdxl.unet, name=name).inject()
    control_lora = control_lora_adapter.control_lora
    info("ControlLoraAdapter initialized")

    lora_layers = load_lora_layers(name, fooocus_state_dict, control_lora)
    info("LoRA layers loaded")

    load_zero_convolutions(fooocus_state_dict, control_lora)
    info("ZeroConvolution layers loaded")

    load_condition_encoder(fooocus_state_dict, control_lora)
    info("ConditionEncoder loaded")

    refiners_state_dict: dict[str, Tensor] = {}
    convert_lora_layers(lora_layers, control_lora, refiners_state_dict)
    info("LoRA layers converted to refiners format")

    convert_zero_convolutions(control_lora, refiners_state_dict)
    info("ZeroConvolution layers converted to refiners format")

    convert_condition_encoder(control_lora, refiners_state_dict)
    info("ConditionEncoder converted to refiners format")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_to_safetensors(path=output_path, tensors=refiners_state_dict)
    info(f"Converted ControlLora state dict saved to disk at: {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert ControlLora (from Fooocus) weights to refiners.",
    )

    parser.add_argument(
        "--from",
        type=Path,
        dest="source_path",
        default="lllyasviel/misc:control-lora-canny-rank128.safetensors",
        help="Path to the state_dict of the ControlLora, or a Hugging Face model ID.",
    )

    parser.add_argument(
        "--to",
        type=Path,
        dest="output_path",
        help=(
            "Path to save the converted model (extension will be .safetensors). "
            "If not specified, the output path will be the source path with the extension changed to .safetensors."
        ),
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        dest="verbose",
        default=False,
        help="Use this flag to print verbose output during conversion.",
    )

    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(
            level=logging.INFO,
            format="%(levelname)s: %(message)s",
        )

    if not args.source_path.exists():
        repo_id, filename = str(args.source_path).split(":")
        args.source_path = Path(
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
            )
        )
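    # Note (added comment): when --from is not an existing local path, it is interpreted as
    # "repo_id:filename" (the format of the default value) and resolved with hf_hub_download above.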
    if args.output_path is None:
        args.output_path = Path(f"refiners_{args.source_path.stem}.safetensors")

    convert(
        name=args.source_path.stem,
        state_dict_path=args.source_path,
        output_path=args.output_path,
    )
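For reference, a minimal sketch of how the conversion above might be driven directly from Python instead of the command line. The paths are the ones used by the test-weights helpers below; the import line is an assumption, since the script is meant to be executed as a standalone file rather than imported.

from pathlib import Path

from convert_fooocus_control_lora import convert  # hypothetical import; the script is normally run directly

# roughly equivalent to:
#   python scripts/conversion/convert_fooocus_control_lora.py \
#       --from tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors \
#       --to tests/weights/control_lora/refiners_control-lora-canny-rank128.safetensors --verbose
convert(
    name="control-lora-canny-rank128",
    state_dict_path=Path("tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors"),
    output_path=Path("tests/weights/control_lora/refiners_control-lora-canny-rank128.safetensors"),
)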
@@ -273,6 +273,20 @@ def download_controlnet():
    download_files(urls, mfidabel_folder)


def download_control_lora_fooocus():
    base_folder = os.path.join(test_weights_dir, "lllyasviel", "misc")
    control_loras = [
        "control-lora-canny-rank128.safetensors",
        "fooocus_xl_cpds_128.safetensors",
    ]

    for control_lora in control_loras:
        download_file(
            url=f"https://huggingface.co/lllyasviel/misc/resolve/main/{control_lora}",
            dest_folder=base_folder,
        )


def download_unclip():
    base_folder = os.path.join(test_weights_dir, "stabilityai", "stable-diffusion-2-1-unclip")
    download_file(

@@ -624,6 +638,21 @@ def convert_dinov2():
    )


def convert_control_lora_fooocus():
    run_conversion_script(
        "convert_fooocus_control_lora.py",
        "tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors",
        "tests/weights/control_lora/refiners_control-lora-canny-rank128.safetensors",
        expected_hash="4d505134",
    )
    run_conversion_script(
        "convert_fooocus_control_lora.py",
        "tests/weights/lllyasviel/misc/fooocus_xl_cpds_128.safetensors",
        "tests/weights/control_lora/refiners_fooocus_xl_cpds_128.safetensors",
        expected_hash="d81aa461",
    )


def download_all():
    print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
    download_sd15("runwayml/stable-diffusion-v1-5")

@@ -639,6 +668,7 @@ def download_all():
    download_t2i_adapter()
    download_sam()
    download_dinov2()
    download_control_lora_fooocus()


def convert_all():

@@ -654,6 +684,7 @@ def convert_all():
    convert_t2i_adapter()
    convert_sam()
    convert_dinov2()
    convert_control_lora_fooocus()


def main():