From 5fee723cd1001f6bcb577289c9d607fd3428316a Mon Sep 17 00:00:00 2001
From: Laurent
Date: Wed, 14 Feb 2024 15:26:47 +0000
Subject: [PATCH] write ControlLora weight conversion script

---
 .../convert_fooocus_control_lora.py | 349 ++++++++++++++++++
 scripts/prepare_test_weights.py     |  31 ++
 2 files changed, 380 insertions(+)
 create mode 100644 scripts/conversion/convert_fooocus_control_lora.py

diff --git a/scripts/conversion/convert_fooocus_control_lora.py b/scripts/conversion/convert_fooocus_control_lora.py
new file mode 100644
index 0000000..0d59829
--- /dev/null
+++ b/scripts/conversion/convert_fooocus_control_lora.py
@@ -0,0 +1,349 @@
+import argparse
+import logging
+from logging import info
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download  # type: ignore
+from torch import Tensor
+from torch.nn import Parameter as TorchParameter
+
+from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras
+from refiners.fluxion.layers import Conv2d
+from refiners.fluxion.layers.linear import Linear
+from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
+from refiners.foundationals.latent_diffusion.lora import SDLoraManager
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import (
+    ConditionEncoder,
+    ControlLora,
+    ControlLoraAdapter,
+    ZeroConvolution,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
+
+
+def sort_keys(key: str, /) -> tuple[str, int]:
+    """Compute the score of a key, relative to its suffix.
+
+    When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level".
+
+    Args:
+        key: The key to sort.
+
+    Returns:
+        The padded suffix of the key.
+        The score of the key's suffix.
+    """
+    if "time_embed" in key:  # HACK: will place the "time_embed" layers at the very start of the list
+        return ("", -2)
+
+    if "label_emb" in key:  # HACK: will place the "label_emb" layers right after "time_embed"
+        return ("", -1)
+
+    if "proj_out" in key:  # HACK: will place the "proj_out" layers at the end of each "transformer_blocks"
+        return (key.removesuffix("proj_out") + "transformer_blocks.99.ff.net.2", 10)
+
+    return SDLoraManager.sort_keys(key)
+
+
+def load_lora_layers(
+    name: str,
+    state_dict: dict[str, Tensor],
+    control_lora: ControlLora,
+) -> dict[str, Lora[Linear | Conv2d]]:
+    """Load the LoRA layers from the state_dict into the ControlLora.
+
+    Args:
+        name: The name of the LoRA.
+        state_dict: The state_dict of the LoRA.
+        control_lora: The ControlLora to load the LoRA layers into.
+    """
+    # filter from the state_dict the layers that will be used for the LoRA layers
+    lora_weights = {f"{key}.weight": value for key, value in state_dict.items() if ".up" in key or ".down" in key}
+
+    # move the tensors to the device and dtype of the ControlLora
+    lora_weights = {
+        key: value.to(
+            dtype=control_lora.dtype,
+            device=control_lora.device,
+        )
+        for key, value in lora_weights.items()
+    }
+
+    # load every LoRA layer from the filtered state_dict
+    lora_layers = Lora.from_dict(name, state_dict=lora_weights)
+
+    # sort all the LoRA's keys using the `sort_keys` method
+    lora_layers = {
+        key: lora_layers[key]
+        for key in sorted(
+            lora_layers.keys(),
+            key=sort_keys,
+        )
+    }
+
+    # auto-attach the LoRA layers to the U-Net
+    failed_keys = auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
+    assert not failed_keys, f"Failed to auto-attach {len(failed_keys)}/{len(lora_layers)} LoRA layers."
+
+    # eject all the LoRA adapters from the U-Net
+    # because we need each target path as if the adapter wasn't injected
+    for lora_layer in lora_layers.values():
+        lora_adapter = lora_layer.parent
+        assert isinstance(lora_adapter, LoraAdapter)
+        lora_adapter.eject()
+
+    return lora_layers
+
+
+def load_condition_encoder(
+    state_dict: dict[str, Tensor],
+    control_lora: ControlLora,
+) -> None:
+    """Load the ConditionEncoder's Conv2d layers from the state_dict into the ControlLora.
+
+    Args:
+        state_dict: The state_dict of the ConditionEncoder.
+        control_lora: The ControlLora to load the ConditionEncoder's Conv2d layers into.
+    """
+    # filter from the state_dict the layers that will be used for the ConditionEncoder
+    condition_encoder_tensors = {key: value for key, value in state_dict.items() if "input_hint_block" in key}
+
+    # move the tensors to the device and dtype of the ControlLora
+    condition_encoder_tensors = {
+        key: value.to(
+            dtype=control_lora.dtype,
+            device=control_lora.device,
+        )
+        for key, value in condition_encoder_tensors.items()
+    }
+
+    # find the ConditionEncoder's Conv2d layers
+    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
+    condition_encoder_conv2ds = list(condition_encoder_layer.layers(Conv2d))
+
+    # replace the Conv2d layers' weights and biases with the ones from the state_dict
+    for i, layer in enumerate(condition_encoder_conv2ds):
+        layer.weight = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.weight"])
+        layer.bias = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.bias"])
+
+
+def load_zero_convolutions(
+    state_dict: dict[str, Tensor],
+    control_lora: ControlLora,
+) -> None:
+    """Load the ZeroConvolution's Conv2d layers from the state_dict into the ControlLora.
+
+    Args:
+        state_dict: The state_dict of the ZeroConvolution.
+        control_lora: The ControlLora to load the ZeroConvolution's Conv2d layers into.
+    """
+    # filter from the state_dict the layers that will be used for the ZeroConvolution layers
+    zero_convolution_tensors = {key: value for key, value in state_dict.items() if "zero_convs" in key}
+    n = len(zero_convolution_tensors) // 2
+    zero_convolution_tensors[f"zero_convs.{n}.0.weight"] = state_dict["middle_block_out.0.weight"]
+    zero_convolution_tensors[f"zero_convs.{n}.0.bias"] = state_dict["middle_block_out.0.bias"]
+
+    # move the tensors to the device and dtype of the ControlLora
+    zero_convolution_tensors = {
+        key: value.to(
+            dtype=control_lora.dtype,
+            device=control_lora.device,
+        )
+        for key, value in zero_convolution_tensors.items()
+    }
+
+    # find the ZeroConvolution's Conv2d layers
+    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
+    zero_convolution_conv2ds = [layer.ensure_find(Conv2d) for layer in zero_convolution_layers]
+
+    # replace the Conv2d layers' weights and biases with the ones from the state_dict
+    for i, layer in enumerate(zero_convolution_conv2ds):
+        layer.weight = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.weight"])
+        layer.bias = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.bias"])
+
+
+def simplify_key(key: str, prefix: str, index: int | None = None) -> str:
+    """Simplify a key by stripping everything to the left of the prefix.
+
+    Also optionally add a zero-padded index to the prefix.
+
+    Example:
+        >>> simplify_key("foo.bar.ControlLora.something", "ControlLora", 1)
+        "ControlLora_01.something"
+
+        >>> simplify_key("foo.bar.ControlLora.DownBlocks.something", "ControlLora")
+        "ControlLora.DownBlocks.something"
+
+    Args:
+        key: The key to simplify.
+        prefix: The prefix to remove.
+        index: The index to add.
+    """
+    _, right = key.split(prefix, maxsplit=1)
+    if index:
+        return f"{prefix}_{index:02d}{right}"
+    else:
+        return f"{prefix}{right}"
+
+
+def convert_lora_layers(
+    lora_layers: dict[str, Lora[Linear | Conv2d]],
+    control_lora: ControlLora,
+    refiners_state_dict: dict[str, Tensor],
+) -> None:
+    """Convert the LoRA layers to the refiners format.
+
+    Args:
+        lora_layers: The LoRA layers to convert.
+        control_lora: The ControlLora to convert the LoRA layers from.
+        refiners_state_dict: The refiners state dict to update with the converted LoRA layers.
+    """
+    for lora_layer in lora_layers.values():
+        # get the adapter associated with the LoRA layer
+        lora_adapter = lora_layer.parent
+        assert isinstance(lora_adapter, LoraAdapter)
+
+        # get the path of the adapter's target in the ControlLora
+        target = lora_adapter.target
+        path = target.get_path(parent=control_lora.ensure_find_parent(target))
+
+        state_dict = {
+            f"{path}.down": lora_layer.down.weight,
+            f"{path}.up": lora_layer.up.weight,
+        }
+        state_dict = {simplify_key(key, "ControlLora"): param for key, param in state_dict.items()}
+        refiners_state_dict.update(state_dict)
+
+
+def convert_zero_convolutions(
+    control_lora: ControlLora,
+    refiners_state_dict: dict[str, Tensor],
+) -> None:
+    """Convert the ZeroConvolution layers to the refiners format.
+
+    Args:
+        control_lora: The ControlLora to convert the ZeroConvolution layers from.
+        refiners_state_dict: The refiners state dict to update with the converted ZeroConvolution layers.
+    """
+    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
+    for i, zero_convolution_layer in enumerate(zero_convolution_layers):
+        state_dict = zero_convolution_layer.state_dict()
+        path = zero_convolution_layer.get_path()
+        state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
+        state_dict = {simplify_key(key, "ZeroConvolution", i + 1): param for key, param in state_dict.items()}
+        refiners_state_dict.update(state_dict)
+
+
+def convert_condition_encoder(
+    control_lora: ControlLora,
+    refiners_state_dict: dict[str, Tensor],
+) -> None:
+    """Convert the ConditionEncoder to the refiners format.
+
+    Args:
+        control_lora: The ControlLora to convert the ConditionEncoder from.
+        refiners_state_dict: The refiners state dict to update with the converted ConditionEncoder.
+    """
+    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
+    path = condition_encoder_layer.get_path()
+    state_dict = condition_encoder_layer.state_dict()
+    state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
+    state_dict = {simplify_key(key, "ConditionEncoder"): param for key, param in state_dict.items()}
+    refiners_state_dict.update(state_dict)
+
+
+def convert(
+    name: str,
+    state_dict_path: Path,
+    output_path: Path,
+) -> None:
+    sdxl = StableDiffusion_XL()
+    info("Stable Diffusion XL model initialized")
+
+    fooocus_state_dict = load_from_safetensors(state_dict_path)
+    info(f"Fooocus weights loaded from: {state_dict_path}")
+
+    control_lora_adapter = ControlLoraAdapter(target=sdxl.unet, name=name).inject()
+    control_lora = control_lora_adapter.control_lora
+    info("ControlLoraAdapter initialized")
+
+    lora_layers = load_lora_layers(name, fooocus_state_dict, control_lora)
+    info("LoRA layers loaded")
+
+    load_zero_convolutions(fooocus_state_dict, control_lora)
+    info("ZeroConvolution layers loaded")
+
+    load_condition_encoder(fooocus_state_dict, control_lora)
+    info("ConditionEncoder loaded")
+
+    refiners_state_dict: dict[str, Tensor] = {}
+    convert_lora_layers(lora_layers, control_lora, refiners_state_dict)
+    info("LoRA layers converted to refiners format")
+
+    convert_zero_convolutions(control_lora, refiners_state_dict)
+    info("ZeroConvolution layers converted to refiners format")
+
+    convert_condition_encoder(control_lora, refiners_state_dict)
+    info("ConditionEncoder converted to refiners format")
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    save_to_safetensors(path=output_path, tensors=refiners_state_dict)
+    info(f"Converted ControlLora state dict saved to disk at: {output_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Convert ControlLora (from Fooocus) weights to refiners.",
+    )
+
+    parser.add_argument(
+        "--from",
+        type=Path,
+        dest="source_path",
+        default="lllyasviel/misc:control-lora-canny-rank128.safetensors",
+        help="Path to the state_dict of the ControlLora, or a Hugging Face model ID.",
+    )
+
+    parser.add_argument(
+        "--to",
+        type=Path,
+        dest="output_path",
+        help=(
+            "Path to save the converted model (extension will be .safetensors). "
+            "If not specified, the output defaults to refiners_<source stem>.safetensors in the current directory."
+        ),
+    )
+
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        dest="verbose",
+        default=False,
+        help="Use this flag to print verbose output during conversion.",
+    )
+
+    args = parser.parse_args()
+
+    if args.verbose:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(levelname)s: %(message)s",
+        )
+
+    if not args.source_path.exists():
+        repo_id, filename = str(args.source_path).split(":")
+        args.source_path = Path(
+            hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+            )
+        )
+
+    if args.output_path is None:
+        args.output_path = Path(f"refiners_{args.source_path.stem}.safetensors")
+
+    convert(
+        name=args.source_path.stem,
+        state_dict_path=args.source_path,
+        output_path=args.output_path,
+    )
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index f611e1e..ce42c77 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -273,6 +273,20 @@ def download_controlnet():
     download_files(urls, mfidabel_folder)
 
 
+def download_control_lora_fooocus():
+    base_folder = os.path.join(test_weights_dir, "lllyasviel", "misc")
+    control_loras = [
+        "control-lora-canny-rank128.safetensors",
+        "fooocus_xl_cpds_128.safetensors",
+    ]
+
+    for control_lora in control_loras:
+        download_file(
+            url=f"https://huggingface.co/lllyasviel/misc/resolve/main/{control_lora}",
+            dest_folder=base_folder,
+        )
+
+
 def download_unclip():
     base_folder = os.path.join(test_weights_dir, "stabilityai", "stable-diffusion-2-1-unclip")
     download_file(
@@ -624,6 +638,21 @@ def convert_dinov2():
     )
 
 
+def convert_control_lora_fooocus():
+    run_conversion_script(
+        "convert_fooocus_control_lora.py",
+        "tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors",
+        "tests/weights/control_lora/refiners_control-lora-canny-rank128.safetensors",
+        expected_hash="4d505134",
+    )
+    run_conversion_script(
+        "convert_fooocus_control_lora.py",
+        "tests/weights/lllyasviel/misc/fooocus_xl_cpds_128.safetensors",
+        "tests/weights/control_lora/refiners_fooocus_xl_cpds_128.safetensors",
+        expected_hash="d81aa461",
+    )
+
+
 def download_all():
     print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
     download_sd15("runwayml/stable-diffusion-v1-5")
@@ -639,6 +668,7 @@ def download_all():
     download_t2i_adapter()
     download_sam()
     download_dinov2()
+    download_control_lora_fooocus()
 
 
 def convert_all():
@@ -654,6 +684,7 @@ def convert_all():
     convert_t2i_adapter()
     convert_sam()
     convert_dinov2()
+    convert_control_lora_fooocus()
 
 
 def main():
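For reference, the sketch below shows one way to exercise the conversion entry point added by this patch directly from Python, mirroring what convert_control_lora_fooocus() does through run_conversion_script. It is a minimal usage sketch under stated assumptions, not part of the patch: it assumes the repository root as the working directory, assumes scripts/conversion has been put on sys.path so the new script is importable as a module, and reuses the canny ControlLora paths from the prepare_test_weights.py additions above.

# Minimal usage sketch (assumptions: run from the repository root, with
# scripts/conversion on sys.path so the new script can be imported).
import sys
from pathlib import Path

sys.path.append("scripts/conversion")
from convert_fooocus_control_lora import convert

convert(
    name="control-lora-canny-rank128",  # same value as args.source_path.stem in the script's __main__ block
    state_dict_path=Path("tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors"),
    output_path=Path("tests/weights/control_lora/refiners_control-lora-canny-rank128.safetensors"),
)

The equivalent command-line invocation goes through the --from and --to flags defined in the script's __main__ block.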