From 7ca6bd0ccd4af2387925d88b9261ee8b9bb068b4 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Thu, 24 Aug 2023 02:26:37 +0200 Subject: [PATCH] implement the ConvertModule class and refactor conversion scripts --- README.md | 12 +- .../convert_diffusers_autoencoder_kl.py | 67 +++ .../convert_diffusers_controlnet.py} | 74 ++- .../convert_diffusers_lora.py} | 73 ++- scripts/conversion/convert_diffusers_unet.py | 96 +++ .../convert_informative_drawings.py | 64 ++ .../convert_refiners_lora_to_sdwebui.py} | 76 +-- .../convert_transformers_clip_text_model.py | 104 ++++ scripts/convert-clip-weights.py | 52 -- .../convert-informative-drawings-weights.py | 57 -- scripts/convert-sd-lda-weights.py | 53 -- scripts/convert-sd-unet-inpainting-weights.py | 58 -- scripts/convert-sd-unet-weights.py | 58 -- scripts/convert-sdxl-text-encoder-2.py | 57 -- scripts/convert-sdxl-unet-weights.py | 65 --- src/refiners/fluxion/model_converter.py | 545 ++++++++++++++++++ src/refiners/fluxion/utils.py | 234 +------- .../latent_diffusion/test_sdxl_unet.py | 30 +- 18 files changed, 1029 insertions(+), 746 deletions(-) create mode 100644 scripts/conversion/convert_diffusers_autoencoder_kl.py rename scripts/{convert-controlnet-weights.py => conversion/convert_diffusers_controlnet.py} (73%) rename scripts/{convert-lora-weights.py => conversion/convert_diffusers_lora.py} (66%) create mode 100644 scripts/conversion/convert_diffusers_unet.py create mode 100644 scripts/conversion/convert_informative_drawings.py rename scripts/{convert-loras-to-sdwebui.py => conversion/convert_refiners_lora_to_sdwebui.py} (60%) create mode 100644 scripts/conversion/convert_transformers_clip_text_model.py delete mode 100644 scripts/convert-clip-weights.py delete mode 100644 scripts/convert-informative-drawings-weights.py delete mode 100644 scripts/convert-sd-lda-weights.py delete mode 100644 scripts/convert-sd-unet-inpainting-weights.py delete mode 100644 scripts/convert-sd-unet-weights.py delete mode 100644 scripts/convert-sdxl-text-encoder-2.py delete mode 100644 scripts/convert-sdxl-unet-weights.py create mode 100644 src/refiners/fluxion/model_converter.py diff --git a/README.md b/README.md index 841e0ee..b1af641 100644 --- a/README.md +++ b/README.md @@ -212,9 +212,9 @@ Here is how to perform a text-to-image inference using the Stable Diffusion 1.5 Step 1: prepare the model weights in refiners' format: ```bash -python scripts/convert-clip-weights.py --output-file CLIPTextEncoderL.safetensors -python scripts/convert-sd-lda-weights.py --output-file lda.safetensors -python scripts/convert-sd-unet-weights.py --output-file unet.safetensors +python scripts/conversion/convert_transformers_clip_text_model.py --to clip.safetensors +python scripts/conversion/convert_diffusers_autoencoder_kl.py --to lda.safetensors +python scripts/conversion/convert_diffusers_unet.py --to unet.safetensors ``` > Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5 which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stable-diffusion-v1-5` option instead. @@ -223,9 +223,9 @@ Step 2: download and convert a community Pokemon LoRA, e.g. 
[this one](https://h ```bash curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin -python scripts/convert-lora-weights.py \ +python scripts/conversion/convert_diffusers_lora.py \ --from pytorch_lora_weights.bin \ - --output-file pokemon_lora.safetensors + --to pokemon_lora.safetensors ``` Step 3: run inference using the GPU: @@ -238,7 +238,7 @@ import torch sd15 = StableDiffusion_1(device="cuda") -sd15.clip_text_encoder.load_state_dict(load_from_safetensors("CLIPTextEncoderL.safetensors")) +sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors")) sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors")) sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors")) diff --git a/scripts/conversion/convert_diffusers_autoencoder_kl.py b/scripts/conversion/convert_diffusers_autoencoder_kl.py new file mode 100644 index 0000000..de0c9cd --- /dev/null +++ b/scripts/conversion/convert_diffusers_autoencoder_kl.py @@ -0,0 +1,67 @@ +import argparse +from pathlib import Path +import torch +from torch import nn +from diffusers import AutoencoderKL # type: ignore +from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder +from refiners.fluxion.model_converter import ModelConverter + + +class Args(argparse.Namespace): + source_path: str + output_path: str | None + use_half: bool + verbose: bool + + +def setup_converter(args: Args) -> ModelConverter: + target = LatentDiffusionAutoencoder() + source: nn.Module = AutoencoderKL.from_pretrained(pretrained_model_name_or_path=args.source_path, subfolder="vae") # type: ignore + x = torch.randn(1, 3, 512, 512) + converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) + if not converter.run(source_args=(x,)): + raise RuntimeError("Model conversion failed") + return converter + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert a pretrained diffusers AutoencoderKL model to a refiners Latent Diffusion Autoencoder" + ) + parser.add_argument( + "--from", + type=str, + dest="source_path", + default="runwayml/stable-diffusion-v1-5", + help="Path to the source pretrained model (default: 'runwayml/stable-diffusion-v1-5').", + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Path to save the converted model (extension will be .safetensors). If not specified, the output path will" + " be the source path with the extension changed to .safetensors." 
+ ), + ) + parser.add_argument( + "--half", + action="store_true", + dest="use_half", + default=False, + help="Use this flag to save the output file as half precision (default: full precision).", + ) + parser.add_argument( + "--verbose", + action="store_true", + dest="verbose", + default=False, + help="Use this flag to print verbose output during conversion.", + ) + args = parser.parse_args(namespace=Args()) + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}-autoencoder.safetensors" + assert args.output_path is not None + converter = setup_converter(args=args) + converter.save_to_safetensors(path=args.output_path, half=args.use_half) diff --git a/scripts/convert-controlnet-weights.py b/scripts/conversion/convert_diffusers_controlnet.py similarity index 73% rename from scripts/convert-controlnet-weights.py rename to scripts/conversion/convert_diffusers_controlnet.py index 9707222..42bda93 100644 --- a/scripts/convert-controlnet-weights.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -1,18 +1,24 @@ +# pyright: reportPrivateUsage=false +import argparse +from pathlib import Path import torch +from torch import nn from diffusers import ControlNetModel # type: ignore -from refiners.fluxion.utils import ( - forward_order_of_execution, - verify_shape_match, - convert_state_dict, - save_to_safetensors, -) -from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet +from refiners.fluxion.utils import save_to_safetensors +from refiners.fluxion.model_converter import ModelConverter +from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Controlnet from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver from refiners.foundationals.latent_diffusion import SD1UNet +class Args(argparse.Namespace): + source_path: str + output_path: str | None + + @torch.no_grad() -def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]: +def convert(args: Args) -> dict[str, torch.Tensor]: + controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore controlnet = SD1Controlnet(name="mycn") condition = torch.randn(1, 3, 512, 512) @@ -33,10 +39,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]: # to diffusers in order, since we compute the residuals inline instead of # in a separate step. 
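For contrast with the manual reordering handled below: when the source and target layers line up automatically, the new scripts in this patch simply construct a `ModelConverter`, call `run`, and save. A minimal sketch of that common path, mirroring `convert_diffusers_autoencoder_kl.py` above; the model name, dummy input size, and output filename are only illustrative.

```python
import torch
from diffusers import AutoencoderKL  # type: ignore
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

# Source and target models whose basic layers match one-to-one.
source = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")  # type: ignore
target = LatentDiffusionAutoencoder()

# A dummy input is enough: the converter traces both forward passes, pairs layers by
# type and parameter shapes, remaps the source state_dict, and loads it into the target.
x = torch.randn(1, 3, 512, 512)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True)
if converter.run(source_args=(x,)):
    converter.save_to_safetensors(path="lda.safetensors")
```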
- source_order = forward_order_of_execution(module=controlnet_src, example_args=(x, timestep, clip_text_embedding, condition)) # type: ignore - target_order = forward_order_of_execution(module=controlnet, example_args=(x,)) + converter = ModelConverter( + source_model=controlnet_src, target_model=controlnet, skip_output_check=True, verbose=False + ) - broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320]))) + source_order = converter._trace_module_execution_order( + module=controlnet_src, args=(x, timestep, clip_text_embedding, condition), keys_to_skip=[] + ) + target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[]) + + broken_k = (str(object=nn.Conv2d), (torch.Size([320, 320, 1, 1]), torch.Size([320]))) expected_source_order = [ "down_blocks.0.attentions.0.proj_in", @@ -75,7 +87,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]: assert target_order[broken_k] == expected_target_order source_order[broken_k] = fixed_source_order - broken_k = ("Conv2d", (torch.Size([640, 640, 1, 1]), torch.Size([640]))) + broken_k = (str(object=nn.Conv2d), (torch.Size([640, 640, 1, 1]), torch.Size([640]))) expected_source_order = [ "down_blocks.1.attentions.0.proj_in", @@ -111,7 +123,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]: assert target_order[broken_k] == expected_target_order source_order[broken_k] = fixed_source_order - broken_k = ("Conv2d", (torch.Size([1280, 1280, 1, 1]), torch.Size([1280]))) + broken_k = (str(object=nn.Conv2d), (torch.Size([1280, 1280, 1, 1]), torch.Size([1280]))) expected_source_order = [ "down_blocks.2.attentions.0.proj_in", @@ -162,7 +174,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]: assert target_order[broken_k] == expected_target_order source_order[broken_k] = fixed_source_order - assert verify_shape_match(source_order=source_order, target_order=target_order) + assert converter._assert_shapes_aligned(source_order=source_order, target_order=target_order), "Shapes not aligned" mapping: dict[str, str] = {} for model_type_shape in source_order: @@ -170,7 +182,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]: target_keys = target_order[model_type_shape] mapping.update(zip(target_keys, source_keys)) - state_dict = convert_state_dict( + state_dict = converter._convert_state_dict( source_state_dict=controlnet_src.state_dict(), target_state_dict=controlnet.state_dict(), state_dict_mapping=mapping, @@ -180,27 +192,33 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]: def main() -> None: - import argparse - - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="Convert a diffusers ControlNet model to a Refiners ControlNet model") parser.add_argument( "--from", type=str, - dest="source", - required=True, - help="Source model", + dest="source_path", + default="lllyasviel/sd-controlnet-depth", + help=( + "Can be a path to a .bin, a .safetensors file, or a model identifier from Hugging Face Hub. Defaults to" + " lllyasviel/sd-controlnet-depth" + ), ) parser.add_argument( - "--output-file", + "--to", type=str, + dest="output_path", required=False, - default="output.safetensors", - help="Path for the output file", + default=None, + help=( + "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" + " source path." 
+ ), ) - args = parser.parse_args() - controlnet_src = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source) # type: ignore - tensors = convert(controlnet_src=controlnet_src) # type: ignore - save_to_safetensors(path=args.output_file, tensors=tensors) + args = parser.parse_args(namespace=Args()) + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}-controlnet.safetensors" + state_dict = convert(args=args) + save_to_safetensors(path=args.output_path, tensors=state_dict) if __name__ == "__main__": diff --git a/scripts/convert-lora-weights.py b/scripts/conversion/convert_diffusers_lora.py similarity index 66% rename from scripts/convert-lora-weights.py rename to scripts/conversion/convert_diffusers_lora.py index 3abd146..b7b568f 100644 --- a/scripts/convert-lora-weights.py +++ b/scripts/conversion/convert_diffusers_lora.py @@ -1,20 +1,17 @@ -# Note: this conversion script currently only support simple LoRAs which adapt -# the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora - +import argparse +from pathlib import Path from typing import cast import torch +from torch import Tensor from torch.nn.init import zeros_ from torch.nn import Parameter as TorchParameter - +from diffusers import DiffusionPipeline # type: ignore import refiners.fluxion.layers as fl - +from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import save_to_safetensors from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target from refiners.adapters.lora import Lora -from refiners.fluxion.utils import create_state_dict_mapping - -from diffusers import DiffusionPipeline # type: ignore def get_weight(linear: fl.Linear) -> torch.Tensor: @@ -31,10 +28,17 @@ def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torc return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)} +class Args(argparse.Namespace): + source_path: str + base_model: str + output_file: str + verbose: bool + + @torch.no_grad() -def process(source: str, base_model: str, output_file: str) -> None: - diffusers_state_dict = torch.load(source, map_location="cpu") # type: ignore - diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model) # type: ignore +def process(args: Args) -> None: + diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore + diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model) # type: ignore diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768) @@ -54,10 +58,16 @@ def process(source: str, base_model: str, output_file: str) -> None: diffusers_args = (x, timestep, clip_text_embeddings) - diffusers_to_refiners = create_state_dict_mapping( - source_model=refiners_model, target_model=diffusers_model, source_args=refiners_args, target_args=diffusers_args + converter = ModelConverter( + source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose ) - assert diffusers_to_refiners is not None, "Model conversion failed" + if not converter.run( + source_args=refiners_args, + target_args=diffusers_args, + ): + raise RuntimeError("Model conversion failed") + + diffusers_to_refiners = converter.get_mapping() apply_loras_to_target(module=refiners_model, 
target=LoraTarget(target), rank=rank, scale=1.0) for layer in refiners_model.layers(layer_type=Lora): @@ -83,36 +93,47 @@ def process(source: str, base_model: str, output_file: str) -> None: state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.") assert len(state_dict) == 320 - save_to_safetensors(path=output_file, tensors=state_dict, metadata=metadata) + save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata) def main() -> None: - import argparse - - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.") parser.add_argument( "--from", type=str, - dest="source", + dest="source_path", required=True, - help="Source file path (.bin)", + help="Source file path (.bin|safetensors) containing the LoRAs.", ) parser.add_argument( "--base-model", type=str, required=False, default="runwayml/stable-diffusion-v1-5", - help="Base model", + help="Base model, used for the UNet structure. Default: runwayml/stable-diffusion-v1-5", ) parser.add_argument( - "--output-file", + "--to", type=str, + dest="output_path", required=False, - default="output.safetensors", - help="Path for the output file", + default=None, + help=( + "Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as" + " the source path." + ), ) - args = parser.parse_args() - process(source=args.source, base_model=args.base_model, output_file=args.output_file) + parser.add_argument( + "--verbose", + action="store_true", + dest="verbose", + default=False, + help="Use this flag to print verbose output during conversion.", + ) + args = parser.parse_args(namespace=Args()) + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors" + process(args=args) if __name__ == "__main__": diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py new file mode 100644 index 0000000..fea8f2a --- /dev/null +++ b/scripts/conversion/convert_diffusers_unet.py @@ -0,0 +1,96 @@ +import argparse +from pathlib import Path +import torch +from torch import nn +from refiners.fluxion.model_converter import ModelConverter +from diffusers import UNet2DConditionModel # type: ignore +from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet + + +class Args(argparse.Namespace): + source_path: str + output_path: str | None + half: bool + verbose: bool + + +def setup_converter(args: Args) -> ModelConverter: + source: nn.Module = UNet2DConditionModel.from_pretrained( # type: ignore + pretrained_model_name_or_path=args.source_path, subfolder="unet" + ) + source_in_channels: int = source.config.in_channels # type: ignore + source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore + source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore + target = ( + SDXLUNet(in_channels=source_in_channels) + if source_has_time_ids + else SD1UNet(in_channels=source_in_channels, clip_embedding_dim=source_clip_embedding_dim) + ) + + x = torch.randn(1, source_in_channels, 32, 32) + timestep = torch.tensor(data=[0]) + clip_text_embeddings = torch.randn(1, 77, source_clip_embedding_dim) + + target.set_timestep(timestep=timestep) + target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) + added_cond_kwargs = {} + if source_has_time_ids: + added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} + 
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) + target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) + + target_args = (x,) + source_args = { + "positional": (x, timestep, clip_text_embeddings), + "keyword": {"added_cond_kwargs": added_cond_kwargs} if source_has_time_ids else {}, + } + + converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) + if not converter.run( + source_args=source_args, + target_args=target_args, + ): + raise RuntimeError("Model conversion failed") + return converter + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Converts a Diffusion UNet model to a Refiners SD1UNet or SDXLUNet model" + ) + parser.add_argument( + "--from", + type=str, + dest="source_path", + default="runwayml/stable-diffusion-v1-5", + help=( + "Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:" + " runwayml/stable-diffusion-v1-5" + ), + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" + " source path." + ), + ) + parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True") + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Prints additional information during conversion. Default: False", + ) + args = parser.parse_args(namespace=Args()) + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}-unet.safetensors" + converter = setup_converter(args=args) + converter.save_to_safetensors(path=args.output_path, half=args.half) + + +if __name__ == "__main__": + main() diff --git a/scripts/conversion/convert_informative_drawings.py b/scripts/conversion/convert_informative_drawings.py new file mode 100644 index 0000000..75f4000 --- /dev/null +++ b/scripts/conversion/convert_informative_drawings.py @@ -0,0 +1,64 @@ +import argparse +from typing import TYPE_CHECKING, cast +import torch +from torch import nn +from refiners.fluxion.model_converter import ModelConverter +from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings + +try: + from model import Generator # type: ignore +except ImportError: + raise ImportError( + "Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your" + " PYTHONPATH" + ) +if TYPE_CHECKING: + Generator = cast(nn.Module, Generator) + + +class Args(argparse.Namespace): + source_path: str + output_path: str + verbose: bool + half: bool + + +def setup_converter(args: Args) -> ModelConverter: + source = Generator(3, 1, 3) + source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu")) # type: ignore + source.eval() + target = InformativeDrawings() + x = torch.randn(1, 3, 512, 512) + converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) + if not converter.run(source_args=(x,)): + raise RuntimeError("Model conversion failed") + return converter + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Converts a pretrained Informative Drawings model to a refiners Informative Drawings model" + ) + parser.add_argument( + "--from", + type=str, + dest="source_path", + default="model2.pth", + help="Path to the source model. 
(default: 'model2.pth').", + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default="informative-drawings.safetensors", + help="Path to save the converted model. (default: 'informative-drawings.safetensors').", + ) + parser.add_argument("--verbose", action="store_true", dest="verbose") + parser.add_argument("--half", action="store_true", dest="half") + args = parser.parse_args(namespace=Args()) + converter = setup_converter(args=args) + converter.save_to_safetensors(path=args.output_path, half=args.half) + + +if __name__ == "__main__": + main() diff --git a/scripts/convert-loras-to-sdwebui.py b/scripts/conversion/convert_refiners_lora_to_sdwebui.py similarity index 60% rename from scripts/convert-loras-to-sdwebui.py rename to scripts/conversion/convert_refiners_lora_to_sdwebui.py index b5613e8..f263809 100644 --- a/scripts/convert-loras-to-sdwebui.py +++ b/scripts/conversion/convert_refiners_lora_to_sdwebui.py @@ -1,49 +1,36 @@ +import argparse +from functools import partial +from torch import Tensor from refiners.fluxion.utils import ( load_from_safetensors, load_metadata_from_safetensors, save_to_safetensors, ) +from convert_diffusers_unet import setup_converter as convert_unet, Args as UnetConversionArgs +from convert_transformers_clip_text_model import ( + setup_converter as convert_text_encoder, + Args as TextEncoderConversionArgs, +) from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL -from refiners.foundationals.clip.tokenizer import CLIPTokenizer -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget -from refiners.fluxion.layers.module import Module +from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet import refiners.fluxion.layers as fl -from refiners.fluxion.utils import create_state_dict_mapping - -import torch - -from diffusers import DiffusionPipeline # type: ignore -from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore -from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore -@torch.no_grad() -def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: SD1UNet) -> dict[str, str] | None: - x = torch.randn(1, 4, 32, 32) - timestep = torch.tensor(data=[0]) - clip_text_embeddings = torch.randn(1, 77, 768) - - src_args = (x, timestep, clip_text_embeddings) - dst_model.set_timestep(timestep=timestep) - dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) - dst_args = (x,) - - return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore +def get_unet_mapping(source_path: str) -> dict[str, str]: + args = UnetConversionArgs(source_path=source_path, verbose=False) + return convert_unet(args=args).get_mapping() -@torch.no_grad() -def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None: - tokenizer = dst_model.find(layer_type=CLIPTokenizer) - assert tokenizer is not None, "Could not find tokenizer" - tokens = tokenizer("Nice cat") - return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore +def get_text_encoder_mapping(source_path: str) -> dict[str, str]: + args = TextEncoderConversionArgs(source_path=source_path, subfolder="text_encoder", verbose=False) + return convert_text_encoder( + args=args, + ).get_mapping() def 
main() -> None: - import argparse - - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="Converts a refiner's LoRA weights to SD-WebUI's LoRA weights") parser.add_argument( "-i", "--input-file", @@ -55,7 +42,7 @@ def main() -> None: "-o", "--output-file", type=str, - required=True, + default="sdwebui_loras.safetensors", help="Path to the output file with sd-webui's LoRA weights (safetensors format)", ) parser.add_argument( @@ -66,27 +53,22 @@ def main() -> None: help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)", ) args = parser.parse_args() - metadata = load_metadata_from_safetensors(path=args.input_file) - assert metadata is not None + assert metadata is not None, f"Could not load metadata from {args.input_file}" tensors = load_from_safetensors(path=args.input_file) - diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.sd15) # type: ignore - - state_dict: dict[str, torch.Tensor] = {} + state_dict: dict[str, Tensor] = {} for meta_key, meta_value in metadata.items(): match meta_key: case "unet_targets": - src_model = diffusers_sd.unet # type: ignore - dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768) - create_mapping = create_unet_mapping + model = SD1UNet(in_channels=4, clip_embedding_dim=768) + create_mapping = partial(get_unet_mapping, source_path=args.sd15) key_prefix = "unet." lora_prefix = "lora_unet_" case "text_encoder_targets": - src_model = diffusers_sd.text_encoder # type: ignore - dst_model = CLIPTextEncoderL() - create_mapping = create_text_encoder_mapping + model = CLIPTextEncoderL() + create_mapping = partial(get_text_encoder_mapping, source_path=args.sd15) key_prefix = "text_encoder." lora_prefix = "lora_te_" case "lda_targets": @@ -94,8 +76,8 @@ def main() -> None: case _: raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}") - submodule_to_key: dict[Module, str] = {} - for name, submodule in dst_model.named_modules(): + submodule_to_key: dict[fl.Module, str] = {} + for name, submodule in model.named_modules(): submodule_to_key[submodule] = name # SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.: @@ -110,14 +92,14 @@ def main() -> None: # # [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225 - refiners_to_diffusers = create_mapping(src_model, dst_model) # type: ignore + refiners_to_diffusers = create_mapping() assert refiners_to_diffusers is not None # Compute the corresponding diffusers' keys where LoRA layers must be applied lora_injection_points: list[str] = [ refiners_to_diffusers[submodule_to_key[linear]] for target in [LoraTarget(t) for t in meta_value.split(sep=",")] - for layer in dst_model.layers(layer_type=target.get_class()) + for layer in model.layers(layer_type=target.get_class()) for linear in layer.layers(layer_type=fl.Linear) ] diff --git a/scripts/conversion/convert_transformers_clip_text_model.py b/scripts/conversion/convert_transformers_clip_text_model.py new file mode 100644 index 0000000..e97c988 --- /dev/null +++ b/scripts/conversion/convert_transformers_clip_text_model.py @@ -0,0 +1,104 @@ +import argparse +from pathlib import Path +from torch import nn +from refiners.fluxion.model_converter import ModelConverter +from transformers import CLIPTextModelWithProjection # type: ignore +from refiners.foundationals.clip.text_encoder import CLIPTextEncoder +from refiners.foundationals.clip.tokenizer import CLIPTokenizer 
+import refiners.fluxion.layers as fl + + +class Args(argparse.Namespace): + source_path: str + subfolder: str + output_path: str | None + use_half: bool + verbose: bool + + +def setup_converter(args: Args) -> ModelConverter: + source: nn.Module = CLIPTextModelWithProjection.from_pretrained( # type: ignore + pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder + ) + assert isinstance(source, nn.Module), "Source model is not a nn.Module" + architecture: str = source.config.architectures[0] # type: ignore + embedding_dim: int = source.config.hidden_size # type: ignore + projection_dim: int = source.config.projection_dim # type: ignore + num_layers: int = source.config.num_hidden_layers # type: ignore + num_attention_heads: int = source.config.num_attention_heads # type: ignore + feed_forward_dim: int = source.config.intermediate_size # type: ignore + use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore + target = CLIPTextEncoder( + embedding_dim=embedding_dim, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + feedforward_dim=feed_forward_dim, + use_quick_gelu=use_quick_gelu, + ) + match architecture: + case "CLIPTextModel": + source.text_projection = fl.Identity() + case "CLIPTextModelWithProjection": + target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False)) + case _: + raise RuntimeError(f"Unsupported architecture: {architecture}") + text = "What a nice cat you have there!" + tokenizer = target.find(layer_type=CLIPTokenizer) + assert tokenizer is not None, "Could not find tokenizer" + tokens = tokenizer(text) + converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) + if not converter.run(source_args=(tokens,), target_args=(text,)): + raise RuntimeError("Model conversion failed") + return converter + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Converts a CLIPTextEncoder from the library transformers from the HuggingFace Hub to refiners." + ) + parser.add_argument( + "--from", + type=str, + dest="source_path", + default="runwayml/stable-diffusion-v1-5", + help=( + "Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:" + " runwayml/stable-diffusion-v1-5" + ), + ) + parser.add_argument( + "--subfolder", + type=str, + dest="subfolder", + default="text_encoder", + help=( + "Subfolder in the source path where the model is located inside the Hub. Default: text_encoder (for" + " CLIPTextModel)" + ), + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" + " source path." + ), + ) + parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True") + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Prints additional information during conversion. 
Default: False", + ) + args = parser.parse_args(namespace=Args()) + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors" + converter = setup_converter(args=args) + converter.save_to_safetensors(path=args.output_path, half=args.half) + + +if __name__ == "__main__": + main() diff --git a/scripts/convert-clip-weights.py b/scripts/convert-clip-weights.py deleted file mode 100644 index 3e52e4a..0000000 --- a/scripts/convert-clip-weights.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch - -from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors - -from diffusers import DiffusionPipeline # type: ignore -from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore - -from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL -from refiners.foundationals.clip.tokenizer import CLIPTokenizer - - -@torch.no_grad() -def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]: - dst_model = CLIPTextEncoderL() - tokenizer = dst_model.find(layer_type=CLIPTokenizer) - assert tokenizer is not None, "Could not find tokenizer" - tokens = tokenizer("Nice cat") - mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore - assert mapping is not None, "Model conversion failed" - state_dict = convert_state_dict( - source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping - ) - return {k: v.half() for k, v in state_dict.items()} - - -def main() -> None: - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - dest="source", - required=False, - default="runwayml/stable-diffusion-v1-5", - help="Source model", - ) - parser.add_argument( - "--output-file", - type=str, - required=False, - default="CLIPTextEncoderL.safetensors", - help="Path for the output file", - ) - args = parser.parse_args() - src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder # type: ignore - tensors = convert(src_model=src_model) # type: ignore - save_to_safetensors(path=args.output_file, tensors=tensors) - - -if __name__ == "__main__": - main() diff --git a/scripts/convert-informative-drawings-weights.py b/scripts/convert-informative-drawings-weights.py deleted file mode 100644 index 49034df..0000000 --- a/scripts/convert-informative-drawings-weights.py +++ /dev/null @@ -1,57 +0,0 @@ -# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings -# Code is at https://github.com/carolineec/informative-drawings -# Copy `model.py` in your `PYTHONPATH`. You can edit it to remove un-necessary code -# and imports if you want, we only need `Generator`. 
- -import torch - -from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors - -from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings -from model import Generator # type: ignore - - -@torch.no_grad() -def convert(checkpoint: str, device: torch.device | str) -> dict[str, torch.Tensor]: - src_model = Generator(3, 1, 3) # type: ignore - src_model.load_state_dict(torch.load(checkpoint, map_location=device)) # type: ignore - src_model.eval() # type: ignore - - dst_model = InformativeDrawings() - - x = torch.randn(1, 3, 512, 512) - - mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore - assert mapping is not None, "Model conversion failed" - state_dict = convert_state_dict(source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping) # type: ignore - return {k: v.half() for k, v in state_dict.items()} - - -def main() -> None: - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - dest="source", - required=False, - default="model2.pth", - help="Source model", - ) - parser.add_argument( - "--output-file", - type=str, - required=False, - default="informative-drawings.safetensors", - help="Path for the output file", - ) - args = parser.parse_args() - device = "cuda" if torch.cuda.is_available() else "cpu" - - tensors = convert(checkpoint=args.source, device=device) - save_to_safetensors(path=args.output_file, tensors=tensors) - - -if __name__ == "__main__": - main() diff --git a/scripts/convert-sd-lda-weights.py b/scripts/convert-sd-lda-weights.py deleted file mode 100644 index e800292..0000000 --- a/scripts/convert-sd-lda-weights.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch - -from refiners.fluxion.utils import ( - create_state_dict_mapping, - convert_state_dict, - save_to_safetensors, -) - -from diffusers import DiffusionPipeline # type: ignore -from diffusers.models.autoencoder_kl import AutoencoderKL # type: ignore - -from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder - - -@torch.no_grad() -def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]: - dst_model = LatentDiffusionAutoencoder() - x = torch.randn(1, 3, 512, 512) - mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore - assert mapping is not None, "Model conversion failed" - state_dict = convert_state_dict( - source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping - ) - return {k: v.half() for k, v in state_dict.items()} - - -def main() -> None: - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - dest="source", - required=False, - default="runwayml/stable-diffusion-v1-5", - help="Source model", - ) - parser.add_argument( - "--output-file", - type=str, - required=False, - default="lda.safetensors", - help="Path for the output file", - ) - args = parser.parse_args() - src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).vae # type: ignore - tensors = convert(src_model=src_model) # type: ignore - save_to_safetensors(path=args.output_file, tensors=tensors) - - -if __name__ == "__main__": - main() diff --git a/scripts/convert-sd-unet-inpainting-weights.py b/scripts/convert-sd-unet-inpainting-weights.py deleted file mode 100644 
index e4aa6c4..0000000 --- a/scripts/convert-sd-unet-inpainting-weights.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch - -from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors - -from diffusers import StableDiffusionInpaintPipeline # type: ignore -from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore - -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet - - -@torch.no_grad() -def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]: - dst_model = SD1UNet(in_channels=9, clip_embedding_dim=768) - - x = torch.randn(1, 9, 32, 32) - timestep = torch.tensor(data=[0]) - clip_text_embeddings = torch.randn(1, 77, 768) - - src_args = (x, timestep, clip_text_embeddings) - dst_model.set_timestep(timestep=timestep) - dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) - dst_args = (x,) - - mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore - assert mapping is not None, "Model conversion failed" - state_dict = convert_state_dict( - source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping - ) - return {k: v.half() for k, v in state_dict.items()} - - -def main() -> None: - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - dest="source", - required=False, - default="runwayml/stable-diffusion-inpainting", - help="Source model", - ) - parser.add_argument( - "--output-file", - type=str, - required=False, - default="stable_diffusion_1_5_inpainting_unet.safetensors", - help="Path for the output file", - ) - args = parser.parse_args() - src_model = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore - tensors = convert(src_model=src_model) # type: ignore - save_to_safetensors(path=args.output_file, tensors=tensors) - - -if __name__ == "__main__": - main() diff --git a/scripts/convert-sd-unet-weights.py b/scripts/convert-sd-unet-weights.py deleted file mode 100644 index 232411a..0000000 --- a/scripts/convert-sd-unet-weights.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch - -from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors - -from diffusers import DiffusionPipeline # type: ignore -from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore - -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet - - -@torch.no_grad() -def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]: - dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768) - - x = torch.randn(1, 4, 32, 32) - timestep = torch.tensor(data=[0]) - clip_text_embeddings = torch.randn(1, 77, 768) - - src_args = (x, timestep, clip_text_embeddings) - dst_model.set_timestep(timestep=timestep) - dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) - dst_args = (x,) - - mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore - assert mapping is not None, "Model conversion failed" - state_dict = convert_state_dict( - source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping - ) - return {k: v.half() for k, v in state_dict.items()} - - -def main() -> None: - import argparse - - parser = 
argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - dest="source", - required=False, - default="runwayml/stable-diffusion-v1-5", - help="Source model", - ) - parser.add_argument( - "--output-file", - type=str, - required=False, - default="stable_diffusion_1_5_unet.safetensors", - help="Path for the output file", - ) - args = parser.parse_args() - src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore - tensors = convert(src_model=src_model) # type: ignore - save_to_safetensors(path=args.output_file, tensors=tensors) - - -if __name__ == "__main__": - main() diff --git a/scripts/convert-sdxl-text-encoder-2.py b/scripts/convert-sdxl-text-encoder-2.py deleted file mode 100644 index c1cee00..0000000 --- a/scripts/convert-sdxl-text-encoder-2.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch - -from safetensors.torch import save_file # type: ignore -from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict - -from diffusers import DiffusionPipeline # type: ignore -from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore - -from refiners.foundationals.clip.tokenizer import CLIPTokenizer -from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG -import refiners.fluxion.layers as fl - - -@torch.no_grad() -def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]: - dst_model = CLIPTextEncoderG() - # Extra projection layer (see CLIPTextModelWithProjection in transformers) - dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False)) - tokenizer = dst_model.find(layer_type=CLIPTokenizer) - assert tokenizer is not None, "Could not find tokenizer" - tokens = tokenizer("Nice cat") - mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore - if mapping is None: - raise RuntimeError("Could not create state dict mapping") - state_dict = convert_state_dict( - source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping - ) - return {k: v.half() for k, v in state_dict.items()} - - -def main() -> None: - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - dest="source", - required=False, - default="stabilityai/stable-diffusion-xl-base-0.9", - help="Source model", - ) - parser.add_argument( - "--output-file", - type=str, - required=False, - default="CLIPTextEncoderG.safetensors", - help="Path for the output file", - ) - args = parser.parse_args() - src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2 # type: ignore - tensors = convert(src_model=src_model) # type: ignore - save_file(tensors=tensors, filename=args.output_file) - - -if __name__ == "__main__": - main() diff --git a/scripts/convert-sdxl-unet-weights.py b/scripts/convert-sdxl-unet-weights.py deleted file mode 100644 index ea61277..0000000 --- a/scripts/convert-sdxl-unet-weights.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - -from safetensors.torch import save_file # type: ignore -from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict - -from diffusers import DiffusionPipeline # type: ignore -from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore - -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet - - -@torch.no_grad() -def convert(src_model: 
UNet2DConditionModel) -> dict[str, torch.Tensor]: - dst_model = SDXLUNet(in_channels=4) - - x = torch.randn(1, 4, 32, 32) - timestep = torch.tensor(data=[0]) - clip_text_embeddings = torch.randn(1, 77, 2048) - - added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} - src_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs) - dst_model.set_timestep(timestep=timestep) - dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) - dst_model.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) - dst_model.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) - dst_args = (x,) - - mapping = create_state_dict_mapping( - source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args # type: ignore - ) - if mapping is None: - raise RuntimeError("Could not create state dict mapping") - state_dict = convert_state_dict( - source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping - ) - return {k: v for k, v in state_dict.items()} - - -def main() -> None: - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - dest="source", - required=False, - default="stabilityai/stable-diffusion-xl-base-0.9", - help="Source model", - ) - parser.add_argument( - "--output-file", - type=str, - required=False, - default="stable_diffusion_xl_unet.safetensors", - help="Path for the output file", - ) - args = parser.parse_args() - src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore - tensors = convert(src_model=src_model) # type: ignore - save_file(tensors=tensors, filename=args.output_file) - - -if __name__ == "__main__": - main() diff --git a/src/refiners/fluxion/model_converter.py b/src/refiners/fluxion/model_converter.py new file mode 100644 index 0000000..b5c6a57 --- /dev/null +++ b/src/refiners/fluxion/model_converter.py @@ -0,0 +1,545 @@ +from collections import defaultdict +from enum import Enum, auto +from pathlib import Path +from torch import Tensor, nn +from torch.utils.hooks import RemovableHandle +import torch +from typing import Any, DefaultDict, TypedDict + +from refiners.fluxion.utils import norm, save_to_safetensors + +TORCH_BASIC_LAYERS: list[type[nn.Module]] = [ + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.LayerNorm, + nn.GroupNorm, + nn.Embedding, + nn.MaxPool2d, + nn.AvgPool2d, + nn.AdaptiveAvgPool2d, +] + + +ModelTypeShape = tuple[str, tuple[torch.Size, ...]] + + +class ModuleArgsDict(TypedDict): + """Represents positional and keyword arguments passed to a module. + + - `positional`: A tuple of positional arguments. + - `keyword`: A dictionary of keyword arguments. + """ + + positional: tuple[Any, ...] + keyword: dict[str, Any] + + +class ConversionStage(Enum): + """Represents the current stage of the conversion process. + + - `INIT`: The conversion process has not started. + - `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers. + """ + + INIT = auto() + BASIC_LAYERS_MATCH = auto() + SHAPE_AND_LAYERS_MATCH = auto() + MODELS_OUTPUT_AGREE = auto() + + +class ModelConverter: + ModuleArgs = tuple[Any, ...] 
| dict[str, Any] | ModuleArgsDict + stage: ConversionStage = ConversionStage.INIT + _stored_mapping: dict[str, str] | None = None + + def __init__( + self, + source_model: nn.Module, + target_model: nn.Module, + source_keys_to_skip: list[str] | None = None, + target_keys_to_skip: list[str] | None = None, + custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None, + threshold: float = 1e-5, + skip_output_check: bool = False, + verbose: bool = True, + ) -> None: + """ + Create a ModelConverter. + + - `source_model`: The model to convert from. + - `target_model`: The model to convert to. + - `source_keys_to_skip`: A list of keys to skip when tracing the source model. + - `target_keys_to_skip`: A list of keys to skip when tracing the target model. + - `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models. + - `threshold`: The threshold for comparing outputs between the source and target models. + - `skip_output_check`: Whether to skip comparing the outputs of the source and target models. + - `verbose`: Whether to print messages during the conversion process. + + The conversion process consists of three stages: + + 1. Verify that the source and target models have the same number of basic layers. + 2. Find matching shapes and layers between the source and target models. + 3. Convert the source model's state_dict to match the target model's state_dict. + 4. Compare the outputs of the source and target models. + + The conversion process can be run multiple times, and will resume from the last stage. + + ### Example: + ``` + converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False) + is_converted = converter(args) + if is_converted: + converter.save_to_safetensors(path="test.pt") + ``` + """ + self.source_model = source_model + self.target_model = target_model + self.source_keys_to_skip = source_keys_to_skip or [] + self.target_keys_to_skip = target_keys_to_skip or [] + self.custom_layer_mapping = custom_layer_mapping or {} + self.threshold = threshold + self.skip_output_check = skip_output_check + self.verbose = verbose + + def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool: + """ + Run the conversion process. + + - `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments, + a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args` + is not provided, these arguments will also be passed to the target model. + - `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments, + a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. + + ### Returns: + + - `True` if the conversion process is done and the models agree. + + The conversion process consists of three stages: + + 1. Verify that the source and target models have the same number of basic layers. + 2. Find matching shapes and layers between the source and target models. + 3. Convert the source model's state_dict to match the target model's state_dict. + 4. Compare the outputs of the source and target models. + + The conversion process can be run multiple times, and will resume from the last stage. 
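An illustrative sketch of the three argument forms `run` accepts for `source_args`/`target_args`, modeled on how `convert_diffusers_unet.py` in this patch calls it; tensor shapes and keyword names are placeholders that must match the traced model's `forward` signature.

```python
import torch

x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([0])
clip_text_embeddings = torch.randn(1, 77, 768)

# 1. A tuple of positional arguments.
source_args = (x, timestep, clip_text_embeddings)

# 2. A dictionary of keyword arguments (keys must be valid forward() parameter names).
source_args = {"sample": x, "timestep": timestep, "encoder_hidden_states": clip_text_embeddings}

# 3. The explicit ModuleArgsDict form mixing both, as convert_diffusers_unet.py builds for SDXL
#    (2048-dim text embeddings, with added_cond_kwargs passed by keyword).
source_args = {
    "positional": (x, timestep, torch.randn(1, 77, 2048)),
    "keyword": {"added_cond_kwargs": {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}},
}
```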
+ """ + if target_args is None: + target_args = source_args + + match self.stage: + case ConversionStage.MODELS_OUTPUT_AGREE: + self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`") + return True + + case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_models_output_agree_stage(): + self.stage = ConversionStage.MODELS_OUTPUT_AGREE + return True + + case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage( + source_args=source_args, target_args=target_args + ): + self.stage = ( + ConversionStage.SHAPE_AND_LAYERS_MATCH + if not self.skip_output_check + else ConversionStage.MODELS_OUTPUT_AGREE + ) + return self.run(source_args=source_args, target_args=target_args) + + case ConversionStage.INIT if self._run_init_stage(): + self.stage = ConversionStage.BASIC_LAYERS_MATCH + return self.run(source_args=source_args, target_args=target_args) + + case _: + return False + + def __repr__(self) -> str: + return ( + f"ModelConverter(source_model={self.source_model.__class__.__name__}," + f" target_model={self.target_model.__class__.__name__}, stage={self.stage})" + ) + + def __bool__(self) -> bool: + return self.stage == ConversionStage.MODELS_OUTPUT_AGREE + + def get_state_dict(self) -> dict[str, Tensor]: + """Get the converted state_dict.""" + if not self: + raise ValueError("The conversion process is not done yet. Run `converter(args)` first.") + return self.target_model.state_dict() + + def get_mapping(self) -> dict[str, str]: + """Get the mapping between the source and target models' state_dicts.""" + if not self: + raise ValueError("The conversion process is not done yet. Run `converter(args)` first.") + assert self._stored_mapping is not None, "Mapping is not stored" + return self._stored_mapping + + def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None: + """Save the converted model to a SafeTensors file. + + This method can only be called after the conversion process is done. + + - `path`: The path to save the converted model to. + - `metadata`: Metadata to save with the converted model. + - `half`: Whether to save the converted model as half precision. + + ### Raises: + - `ValueError` if the conversion process is not done yet. Run `converter(args)` first. + """ + if not self: + raise ValueError("The conversion process is not done yet. Run `converter(args)` first.") + state_dict = self.get_state_dict() + if half: + state_dict = {key: value.half() for key, value in state_dict.items()} + save_to_safetensors(path=path, tensors=state_dict, metadata=metadata) + + def map_state_dicts( + self, + source_args: ModuleArgs, + target_args: ModuleArgs | None = None, + ) -> dict[str, str] | None: + """ + Find a mapping between the source and target models' state_dicts. + + - `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments, + a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args` + is not provided, these arguments will also be passed to the target model. + - `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments, + a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. + + ### Returns: + - A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict. 
+ """ + if target_args is None: + target_args = source_args + + source_order = self._trace_module_execution_order( + module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip + ) + target_order = self._trace_module_execution_order( + module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip + ) + + if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order): + return None + + mapping: dict[str, str] = {} + for model_type_shape in source_order: + source_keys = source_order[model_type_shape] + target_keys = target_order[model_type_shape] + mapping.update(zip(target_keys, source_keys)) + + return mapping + + def compare_models( + self, + source_args: ModuleArgs, + target_args: ModuleArgs | None = None, + threshold: float = 1e-5, + ) -> bool: + """ + Compare the outputs of the source and target models. + + - `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments, + a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args` + is not provided, these arguments will also be passed to the target model. + - `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments, + a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. + - `threshold`: The threshold for comparing outputs between the source and target models. + """ + if target_args is None: + target_args = source_args + + source_outputs = self._collect_layers_outputs( + module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip + ) + target_outputs = self._collect_layers_outputs( + module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip + ) + + prev_source_key, prev_target_key = None, None + for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs): + diff = norm(source_output - target_output).item() + if diff > threshold: + self._log( + f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and" + f" {target_key}, difference in norm: {diff}" + ) + return False + prev_source_key, prev_target_key = source_key, target_key + + return True + + def _run_init_stage(self) -> bool: + """Run the init stage of the conversion process.""" + is_count_correct = self._verify_basic_layers_count() + is_not_missing_layers = self._verify_missing_basic_layers() + + return is_count_correct and is_not_missing_layers + + def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool: + """Run the basic layers match stage of the conversion process.""" + self._log(message="Finding matching shapes and layers...") + + mapping = self.map_state_dicts(source_args=source_args, target_args=target_args) + self._stored_mapping = mapping + if mapping is None: + self._log(message="Models do not have matching shapes.") + return False + + self._log(message="Found matching shapes and layers. 
+
+        source_state_dict = self.source_model.state_dict()
+        target_state_dict = self.target_model.state_dict()
+        converted_state_dict = self._convert_state_dict(
+            source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
+        )
+        self.target_model.load_state_dict(state_dict=converted_state_dict)
+
+        return True
+
+    def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
+        """Run the shape and layers match stage of the conversion process."""
+        if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
+            self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
+            return True
+        else:
+            self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
+            return False
+
+    def _run_models_output_agree_stage(self) -> bool:
+        """Run the models output agree stage of the conversion process."""
+        self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
+        return True
+
+    def _log(self, message: str) -> None:
+        """Print a message if `verbose` is `True`."""
+        if self.verbose:
+            print(message)
+
+    def _debug_print_shapes(
+        self,
+        shape: ModelTypeShape,
+        source_keys: list[str],
+        target_keys: list[str],
+    ) -> None:
+        """Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
+        self._log(message=f"{shape}")
+        max_len = max(len(source_keys), len(target_keys))
+        for i in range(max_len):
+            source_key = source_keys[i] if i < len(source_keys) else "---"
+            target_key = target_keys[i] if i < len(target_keys) else "---"
+            self._log(f"\t{source_key}\t{target_key}")
+
+    @staticmethod
+    def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        """Unpack the positional and keyword arguments passed to a module."""
+        match module_args:
+            case tuple(positional_args):
+                keyword_args: dict[str, Any] = {}
+            case {"positional": positional_args, "keyword": keyword_args}:
+                pass
+            case _:
+                positional_args = ()
+                keyword_args = dict(**module_args)
+
+        return positional_args, keyword_args
+
+    def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
+        """Infer the type of a basic layer."""
+        for layer_type in TORCH_BASIC_LAYERS:
+            if isinstance(module, layer_type):
+                return layer_type
+
+        for source_type in self.custom_layer_mapping.keys():
+            if isinstance(module, source_type):
+                return source_type
+
+        return None
+
+    def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
+        """Get the signature of a module."""
+        layer_type = self._infer_basic_layer_type(module=module)
+        assert layer_type is not None, f"Module {module} is not a basic layer"
+        param_shapes = [p.shape for p in module.parameters()]
+        return (str(object=layer_type), tuple(param_shapes))
+
+    def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
+        """Count the number of basic layers in a module."""
+        count: DefaultDict[type[nn.Module], int] = defaultdict(int)
+        for submodule in module.modules():
+            layer_type = self._infer_basic_layer_type(module=submodule)
+            if layer_type is not None:
+                count[layer_type] += 1
+
+        return count
+
+    def _verify_basic_layers_count(self) -> bool:
+        """Verify that the source and target models have the same number of basic layers."""
+        source_layers = self._count_basic_layers(module=self.source_model)
+        target_layers = self._count_basic_layers(module=self.target_model)
+
+        diff: dict[type[nn.Module], tuple[int, int]] = {}
+        for layer_type in set(source_layers.keys()) | set(target_layers.keys()):
+            source_count = source_layers.get(layer_type, 0)
+            target_count = target_layers.get(layer_type, 0)
+            if source_count != target_count:
+                diff[layer_type] = (source_count, target_count)
+
+        if diff:
+            message = "Models do not have the same number of basic layers:\n"
+            for layer_type, counts in diff.items():
+                message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
+            self._log(message=message.rstrip())
+            return False
+
+        return True
+
+    def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
+        """Check if a module is a leaf module with weights."""
+        return next(module.parameters(), None) is not None and next(module.children(), None) is None
+
+    def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
+        """Check if a module has weighted leaf modules that are not basic layers."""
+        return [
+            type(submodule)
+            for submodule in module.modules()
+            if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
+        ]
+
+    def _verify_missing_basic_layers(self) -> bool:
+        """Verify that the source and target models do not have missing basic layers."""
+        missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
+        missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)
+
+        if missing_source_layers or missing_target_layers:
+            self._log(
+                message=(
+                    "Models might have missing basic layers. You can either add them to the keys to skip or set"
+                    f" `check_missing_basic_layer` to `False`: {missing_source_layers}, {missing_target_layers}"
+                )
+            )
+            return False
+
+        return True
+
+    @torch.no_grad()
+    def _trace_module_execution_order(
+        self,
+        module: nn.Module,
+        args: ModuleArgs,
+        keys_to_skip: list[str],
+    ) -> dict[ModelTypeShape, list[str]]:
+        """
+        Execute a forward pass and store the order of execution of specific sub-modules.
+
+        - `module`: The module to trace.
+        - `args`: The arguments to pass to the module; it can be either a tuple of positional arguments,
+          a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
+        - `keys_to_skip`: A list of keys to skip when tracing the module.
+
+        ### Returns:
+        - A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
+        """
+        submodule_to_key: dict[nn.Module, str] = {}
+        execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
+
+        def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
+            layer_signature = self.get_module_signature(module=layer)
+            execution_order[layer_signature].append(submodule_to_key[layer])
+
+        hooks: list[RemovableHandle] = []
+        named_modules: list[tuple[str, nn.Module]] = module.named_modules()  # type: ignore
+        for name, submodule in named_modules:
+            if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
+                submodule_to_key[submodule] = name  # type: ignore
+                hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
+                hooks.append(hook)
+
+        positional_args, keyword_args = self._unpack_module_args(module_args=args)
+        module(*positional_args, **keyword_args)
+
+        for hook in hooks:
+            hook.remove()
+
+        return dict(execution_order)
+
+    def _assert_shapes_aligned(
+        self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
+    ) -> bool:
+        """Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
+        model_type_shapes = set(source_order.keys()) | set(target_order.keys())
+        shape_mismatched = False
+
+        for model_type_shape in model_type_shapes:
+            source_keys = source_order.get(model_type_shape, [])
+            target_keys = target_order.get(model_type_shape, [])
+
+            if len(source_keys) != len(target_keys):
+                shape_mismatched = True
+                self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
+
+        return not shape_mismatched
+
+    @staticmethod
+    def _convert_state_dict(
+        source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
+    ) -> dict[str, Tensor]:
+        """Convert the source model's state_dict to match the target model's state_dict."""
+        converted_state_dict: dict[str, Tensor] = {}
+        for target_key in target_state_dict:
+            target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
+            source_prefix = state_dict_mapping[target_prefix]
+            source_key = ".".join([source_prefix, suffix])
+            converted_state_dict[target_key] = source_state_dict[source_key]
+
+        return converted_state_dict
+
+    @torch.no_grad()
+    def _collect_layers_outputs(
+        self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
+    ) -> list[tuple[str, Tensor]]:
+        """
+        Execute a forward pass and store the output of specific sub-modules.
+
+        - `module`: The module to trace.
+        - `args`: The arguments to pass to the module; it can be either a tuple of positional arguments,
+          a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
+        - `keys_to_skip`: A list of keys to skip when tracing the module.
+
+        ### Returns:
+        - A list of tuples containing the key of each sub-module and its output.
+
+        ### Note:
+        - The output of each sub-module is cloned to avoid memory leaks.
+ """ + submodule_to_key: dict[nn.Module, str] = {} + execution_order: list[tuple[str, Tensor]] = [] + + def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None: + execution_order.append((submodule_to_key[layer], output.clone())) + + hooks: list[RemovableHandle] = [] + named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore + for name, submodule in named_modules: + if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip: + submodule_to_key[submodule] = name # type: ignore + hook = submodule.register_forward_hook(hook=collect_execution_order_hook) + hooks.append(hook) + + positional_args, keyword_args = self._unpack_module_args(module_args=args) + module(*positional_args, **keyword_args) + + for hook in hooks: + hook.remove() + + return execution_order diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 75061ce..2b4fcc9 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -1,18 +1,13 @@ -from collections import defaultdict -from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar +from typing import Dict, Iterable, Literal, TypeVar from PIL import Image from numpy import array, float32 from pathlib import Path from safetensors import safe_open as _safe_open # type: ignore from safetensors.torch import save_file as _save_file # type: ignore from torch import norm as _norm, manual_seed as _manual_seed # type: ignore +import torch from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore -from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn -from torch.utils.hooks import RemovableHandle - -if TYPE_CHECKING: - from refiners.fluxion.layers.module import Module +from torch import Tensor, device as Device, dtype as DType T = TypeVar("T") @@ -31,7 +26,7 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0) -> Tensor: return _pad(input=x, pad=pad, value=value) # type: ignore -def interpolate(x: Tensor, factor: float | Size, mode: str = "nearest") -> Tensor: +def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor: return ( _interpolate(x, scale_factor=factor, mode=mode) if isinstance(factor, float | int) @@ -44,7 +39,9 @@ def bidirectional_mapping(mapping: Dict[str, str]) -> Dict[str, str]: def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor: - return tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(0) + return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze( + 0 + ) def tensor_to_image(tensor: Tensor) -> Image.Image: @@ -77,220 +74,3 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None: def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None: _save_file(tensors, path, metadata) # type: ignore - - -class BasicLayers(Enum): - Conv1d = nn.Conv1d - Conv2d = nn.Conv2d - Conv3d = nn.Conv3d - ConvTranspose1d = nn.ConvTranspose1d - ConvTranspose2d = nn.ConvTranspose2d - ConvTranspose3d = nn.ConvTranspose3d - Linear = nn.Linear - BatchNorm1d = nn.BatchNorm1d - BatchNorm2d = nn.BatchNorm2d - BatchNorm3d = nn.BatchNorm3d - LayerNorm = nn.LayerNorm - GroupNorm = nn.GroupNorm - Embedding = nn.Embedding - MaxPool2d = nn.MaxPool2d - AvgPool2d = nn.AvgPool2d - 
-
-
-ModelTypeShape = tuple[str, tuple[Size, ...]]
-
-
-def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
-    """Identify if the provided module matches any in the BasicLayers enum."""
-    for layer_type in BasicLayers:
-        if isinstance(module, layer_type.value):
-            return layer_type
-    return None
-
-
-def get_module_signature(module: "Module") -> ModelTypeShape:
-    """Return a tuple representing the module's type and parameter shapes."""
-    layer_type = infer_basic_layer_type(module=module)
-    assert layer_type is not None, f"Module {module} is not a basic layer"
-    param_shapes = [p.shape for p in module.parameters()]
-    return (str(object=layer_type), tuple(param_shapes))
-
-
-def forward_order_of_execution(
-    module: "Module",
-    example_args: tuple[Any, ...],
-    key_skipper: Callable[[str], bool] | None = None,
-) -> dict[ModelTypeShape, list[str]]:
-    """
-    Determine the execution order of sub-modules during a forward pass.
-
-    Optionally skips specific modules using `key_skipper`.
-    """
-    key_skipper = key_skipper or (lambda _: False)
-
-    submodule_to_key: dict["Module", str] = {}
-    execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
-
-    def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
-        layer_signature = get_module_signature(module=layer)
-        execution_order[layer_signature].append(submodule_to_key[layer])
-
-    hooks: list[RemovableHandle] = []
-    for name, submodule in module.named_modules():
-        if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
-            submodule_to_key[submodule] = name
-            hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
-            hooks.append(hook)
-
-    with no_grad():
-        module(*example_args)
-
-    for hook in hooks:
-        hook.remove()
-
-    return dict(execution_order)
-
-
-def print_side_by_side(
-    shape: ModelTypeShape,
-    source_keys: list[str],
-    target_keys: list[str],
-) -> None:
-    """Print module keys side by side, useful for debugging shape mismatches."""
-    print(f"{shape}")
-    max_len = max(len(source_keys), len(target_keys))
-    for i in range(max_len):
-        source_key = source_keys[i] if i < len(source_keys) else "---"
-        target_key = target_keys[i] if i < len(target_keys) else "---"
-        print(f"\t{source_key}\t{target_key}")
-
-
-def verify_shape_match(
-    source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
-) -> bool:
-    """Check if the sub-modules in source and target have matching shapes."""
-    model_type_shapes = set(source_order.keys()) | set(target_order.keys())
-    shape_missmatched = False
-
-    for model_type_shape in model_type_shapes:
-        source_keys = source_order.get(model_type_shape, [])
-        target_keys = target_order.get(model_type_shape, [])
-
-        if len(source_keys) != len(target_keys):
-            shape_missmatched = True
-            print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
-
-    return not shape_missmatched
-
-
-def create_state_dict_mapping(
-    source_model: "Module",
-    target_model: "Module",
-    source_args: tuple[Any, ...],
-    target_args: tuple[Any, ...] | None = None,
-    source_key_skipper: Callable[[str], bool] | None = None,
-    target_key_skipper: Callable[[str], bool] | None = None,
-) -> dict[str, str] | None:
-    """
-    Create a mapping between state_dict keys of the source and target models.
-
-    This facilitates the transfer of weights when architectures have slight differences.
- """ - if target_args is None: - target_args = source_args - - source_order = forward_order_of_execution( - module=source_model, example_args=source_args, key_skipper=source_key_skipper - ) - target_order = forward_order_of_execution( - module=target_model, example_args=target_args, key_skipper=target_key_skipper - ) - - if not verify_shape_match(source_order=source_order, target_order=target_order): - return None - - mapping: dict[str, str] = {} - for model_type_shape in source_order: - source_keys = source_order[model_type_shape] - target_keys = target_order[model_type_shape] - mapping.update(zip(target_keys, source_keys)) - - return mapping - - -def convert_state_dict( - source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str] -) -> dict[str, Tensor]: - """Convert source state_dict based on the provided mapping to match target state_dict structure.""" - converted_state_dict: dict[str, Tensor] = {} - for target_key in target_state_dict: - target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1) - source_prefix = state_dict_mapping[target_prefix] - source_key = ".".join([source_prefix, suffix]) - converted_state_dict[target_key] = source_state_dict[source_key] - - return converted_state_dict - - -def forward_store_outputs( - module: "Module", - example_args: tuple[Any, ...], - key_skipper: Callable[[str], bool] | None = None, -) -> list[tuple[str, Tensor]]: - """Execute a forward pass and store outputs of specific sub-modules.""" - key_skipper = key_skipper or (lambda _: False) - submodule_to_key: dict["Module", str] = {} - execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list - - def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None: - execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output - - hooks: list[RemovableHandle] = [] - for name, submodule in module.named_modules(): - if (infer_basic_layer_type(module=module) is not None) and not key_skipper(name): - submodule_to_key[submodule] = name - hook = submodule.register_forward_hook(hook=collect_execution_order_hook) - hooks.append(hook) - - with no_grad(): - module(*example_args) - - for hook in hooks: - hook.remove() - - return execution_order - - -def compare_models( - source_model: "Module", - target_model: "Module", - source_args: tuple[Any, ...], - target_args: tuple[Any, ...] | None = None, - source_key_skipper: Callable[[str], bool] | None = None, - target_key_skipper: Callable[[str], bool] | None = None, - threshold: float = 1e-5, -) -> bool: - """ - Compare the outputs of two models given the same inputs. - - Flag if any difference exceeds the given threshold. 
- """ - if target_args is None: - target_args = source_args - - source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper) - target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper) - - prev_source_key, prev_target_key = None, None - for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order): - diff = norm(source_output - target_output).item() - if diff > threshold: - print( - f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and" - f" {target_key}, difference in norm: {diff}" - ) - return False - prev_source_key, prev_target_key = source_key, target_key - - return True diff --git a/tests/foundationals/latent_diffusion/test_sdxl_unet.py b/tests/foundationals/latent_diffusion/test_sdxl_unet.py index 46a65e5..92af508 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_unet.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_unet.py @@ -3,9 +3,10 @@ from pathlib import Path from warnings import warn import pytest import torch +from refiners.fluxion.utils import manual_seed -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet -from refiners.fluxion.utils import compare_models +from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet +from refiners.fluxion.model_converter import ModelConverter @pytest.fixture(scope="module") @@ -47,23 +48,28 @@ def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet: @torch.no_grad() def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None: - torch.manual_seed(seed=0) # type: ignore + source = diffusers_sdxl_unet + target = refiners_sdxl_unet + + manual_seed(seed=0) x = torch.randn(1, 4, 32, 32) timestep = torch.tensor(data=[0]) clip_text_embeddings = torch.randn(1, 77, 2048) added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} - source_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs) - refiners_sdxl_unet.set_timestep(timestep=timestep) - refiners_sdxl_unet.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) - refiners_sdxl_unet.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) - refiners_sdxl_unet.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) + target.set_timestep(timestep=timestep) + target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) + target.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) + target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) target_args = (x,) + source_args = { + "positional": (x, timestep, clip_text_embeddings), + "keyword": {"added_cond_kwargs": added_cond_kwargs}, + } - assert compare_models( - source_model=diffusers_sdxl_unet, - target_model=refiners_sdxl_unet, + converter = ModelConverter(source_model=source, target_model=target, verbose=False, threshold=1e-2) + + assert converter.run( source_args=source_args, target_args=target_args, - threshold=1e-2, )