implement the ModelConverter class and refactor conversion scripts

limiteinductive 2023-08-24 02:26:37 +02:00 committed by Benjamin Trom
parent 3680f9d196
commit 7ca6bd0ccd
18 changed files with 1029 additions and 746 deletions


@ -212,9 +212,9 @@ Here is how to perform a text-to-image inference using the Stable Diffusion 1.5
Step 1: prepare the model weights in refiners' format:
```bash
python scripts/convert-clip-weights.py --output-file CLIPTextEncoderL.safetensors
python scripts/convert-sd-lda-weights.py --output-file lda.safetensors
python scripts/convert-sd-unet-weights.py --output-file unet.safetensors
python scripts/conversion/convert_transformers_clip_text_model.py --to clip.safetensors
python scripts/conversion/convert_diffusers_autoencoder_kl.py --to lda.safetensors
python scripts/conversion/convert_diffusers_unet.py --to unet.safetensors
```
> Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5 which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stable-diffusion-v1-5` option instead.
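For instance, with a local clone the CLIP conversion above can be pointed at it like so (the local path is illustrative):
```bash
python scripts/conversion/convert_transformers_clip_text_model.py \
  --from /path/to/stable-diffusion-v1-5 \
  --to clip.safetensors
```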
@ -223,9 +223,9 @@ Step 2: download and convert a community Pokemon LoRA, e.g. [this one](https://h
```bash
curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
python scripts/convert-lora-weights.py \
python scripts/conversion/convert_diffusers_lora.py \
--from pytorch_lora_weights.bin \
--output-file pokemon_lora.safetensors
--to pokemon_lora.safetensors
```
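With the new flags, `--to` may also be omitted, in which case the output name is derived from the source file's stem:
```bash
python scripts/conversion/convert_diffusers_lora.py --from pytorch_lora_weights.bin
# writes pytorch_lora_weights-refiners.safetensors
```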
Step 3: run inference using the GPU:
@ -238,7 +238,7 @@ import torch
sd15 = StableDiffusion_1(device="cuda")
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("CLIPTextEncoderL.safetensors"))
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors"))
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))


@ -0,0 +1,67 @@
import argparse
from pathlib import Path
import torch
from torch import nn
from diffusers import AutoencoderKL # type: ignore
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.fluxion.model_converter import ModelConverter
class Args(argparse.Namespace):
source_path: str
output_path: str | None
use_half: bool
verbose: bool
def setup_converter(args: Args) -> ModelConverter:
target = LatentDiffusionAutoencoder()
source: nn.Module = AutoencoderKL.from_pretrained(pretrained_model_name_or_path=args.source_path, subfolder="vae") # type: ignore
x = torch.randn(1, 3, 512, 512)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(x,)):
raise RuntimeError("Model conversion failed")
return converter
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert a pretrained diffusers AutoencoderKL model to a refiners Latent Diffusion Autoencoder"
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="runwayml/stable-diffusion-v1-5",
help="Path to the source pretrained model (default: 'runwayml/stable-diffusion-v1-5').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
" be the source path with the extension changed to .safetensors."
),
)
parser.add_argument(
"--half",
action="store_true",
dest="use_half",
default=False,
help="Use this flag to save the output file as half precision (default: full precision).",
)
parser.add_argument(
"--verbose",
action="store_true",
dest="verbose",
default=False,
help="Use this flag to print verbose output during conversion.",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-autoencoder.safetensors"
assert args.output_path is not None
converter = setup_converter(args=args)
converter.save_to_safetensors(path=args.output_path, half=args.use_half)
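A minimal invocation of this script, assuming the default source repository (the output filename mirrors the README example):
```bash
python scripts/conversion/convert_diffusers_autoencoder_kl.py --to lda.safetensors --half
```
Omitting `--to` would instead produce `stable-diffusion-v1-5-autoencoder.safetensors`, derived from the source path's stem.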


@ -1,18 +1,24 @@
# pyright: reportPrivateUsage=false
import argparse
from pathlib import Path
import torch
from torch import nn
from diffusers import ControlNetModel # type: ignore
from refiners.fluxion.utils import (
forward_order_of_execution,
verify_shape_match,
convert_state_dict,
save_to_safetensors,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Controlnet
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion import SD1UNet
class Args(argparse.Namespace):
source_path: str
output_path: str | None
@torch.no_grad()
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
def convert(args: Args) -> dict[str, torch.Tensor]:
controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
controlnet = SD1Controlnet(name="mycn")
condition = torch.randn(1, 3, 512, 512)
@ -33,10 +39,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
# to diffusers in order, since we compute the residuals inline instead of
# in a separate step.
source_order = forward_order_of_execution(module=controlnet_src, example_args=(x, timestep, clip_text_embedding, condition)) # type: ignore
target_order = forward_order_of_execution(module=controlnet, example_args=(x,))
converter = ModelConverter(
source_model=controlnet_src, target_model=controlnet, skip_output_check=True, verbose=False
)
broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320])))
source_order = converter._trace_module_execution_order(
module=controlnet_src, args=(x, timestep, clip_text_embedding, condition), keys_to_skip=[]
)
target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[])
broken_k = (str(object=nn.Conv2d), (torch.Size([320, 320, 1, 1]), torch.Size([320])))
expected_source_order = [
"down_blocks.0.attentions.0.proj_in",
@ -75,7 +87,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = ("Conv2d", (torch.Size([640, 640, 1, 1]), torch.Size([640])))
broken_k = (str(object=nn.Conv2d), (torch.Size([640, 640, 1, 1]), torch.Size([640])))
expected_source_order = [
"down_blocks.1.attentions.0.proj_in",
@ -111,7 +123,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = ("Conv2d", (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
broken_k = (str(object=nn.Conv2d), (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
expected_source_order = [
"down_blocks.2.attentions.0.proj_in",
@ -162,7 +174,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
assert verify_shape_match(source_order=source_order, target_order=target_order)
assert converter._assert_shapes_aligned(source_order=source_order, target_order=target_order), "Shapes not aligned"
mapping: dict[str, str] = {}
for model_type_shape in source_order:
@ -170,7 +182,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
target_keys = target_order[model_type_shape]
mapping.update(zip(target_keys, source_keys))
state_dict = convert_state_dict(
state_dict = converter._convert_state_dict(
source_state_dict=controlnet_src.state_dict(),
target_state_dict=controlnet.state_dict(),
state_dict_mapping=mapping,
@ -180,27 +192,33 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description="Convert a diffusers ControlNet model to a Refiners ControlNet model")
parser.add_argument(
"--from",
type=str,
dest="source",
required=True,
help="Source model",
dest="source_path",
default="lllyasviel/sd-controlnet-depth",
help=(
"Can be a path to a .bin, a .safetensors file, or a model identifier from Hugging Face Hub. Defaults to"
" lllyasviel/sd-controlnet-depth"
),
)
parser.add_argument(
"--output-file",
"--to",
type=str,
dest="output_path",
required=False,
default="output.safetensors",
help="Path for the output file",
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
args = parser.parse_args()
controlnet_src = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source) # type: ignore
tensors = convert(controlnet_src=controlnet_src) # type: ignore
save_to_safetensors(path=args.output_file, tensors=tensors)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-controlnet.safetensors"
state_dict = convert(args=args)
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
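An illustrative invocation, assuming the script is exposed as `scripts/conversion/convert_diffusers_controlnet.py` (the filename itself is not visible in this hunk) and relying on the defaults above:
```bash
python scripts/conversion/convert_diffusers_controlnet.py --from lllyasviel/sd-controlnet-depth
```
With no `--to`, the output lands in `sd-controlnet-depth-controlnet.safetensors` per the stem-based fallback.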


@ -1,20 +1,17 @@
# Note: this conversion script currently only supports simple LoRAs which adapt
# the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora
import argparse
from pathlib import Path
from typing import cast
import torch
from torch import Tensor
from torch.nn.init import zeros_
from torch.nn import Parameter as TorchParameter
from diffusers import DiffusionPipeline # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
from refiners.adapters.lora import Lora
from refiners.fluxion.utils import create_state_dict_mapping
from diffusers import DiffusionPipeline # type: ignore
def get_weight(linear: fl.Linear) -> torch.Tensor:
@ -31,10 +28,17 @@ def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torc
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
class Args(argparse.Namespace):
source_path: str
base_model: str
output_path: str | None
verbose: bool
@torch.no_grad()
def process(source: str, base_model: str, output_file: str) -> None:
diffusers_state_dict = torch.load(source, map_location="cpu") # type: ignore
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model) # type: ignore
def process(args: Args) -> None:
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model) # type: ignore
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
@ -54,10 +58,16 @@ def process(source: str, base_model: str, output_file: str) -> None:
diffusers_args = (x, timestep, clip_text_embeddings)
diffusers_to_refiners = create_state_dict_mapping(
source_model=refiners_model, target_model=diffusers_model, source_args=refiners_args, target_args=diffusers_args
converter = ModelConverter(
source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose
)
assert diffusers_to_refiners is not None, "Model conversion failed"
if not converter.run(
source_args=refiners_args,
target_args=diffusers_args,
):
raise RuntimeError("Model conversion failed")
diffusers_to_refiners = converter.get_mapping()
apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
for layer in refiners_model.layers(layer_type=Lora):
@ -83,36 +93,47 @@ def process(source: str, base_model: str, output_file: str) -> None:
state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
assert len(state_dict) == 320
save_to_safetensors(path=output_file, tensors=state_dict, metadata=metadata)
save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata)
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.")
parser.add_argument(
"--from",
type=str,
dest="source",
dest="source_path",
required=True,
help="Source file path (.bin)",
help="Source file path (.bin|safetensors) containing the LoRAs.",
)
parser.add_argument(
"--base-model",
type=str,
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Base model",
help="Base model, used for the UNet structure. Default: runwayml/stable-diffusion-v1-5",
)
parser.add_argument(
"--output-file",
"--to",
type=str,
dest="output_path",
required=False,
default="output.safetensors",
help="Path for the output file",
default=None,
help=(
"Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as"
" the source path."
),
)
args = parser.parse_args()
process(source=args.source, base_model=args.base_model, output_file=args.output_file)
parser.add_argument(
"--verbose",
action="store_true",
dest="verbose",
default=False,
help="Use this flag to print verbose output during conversion.",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
process(args=args)
if __name__ == "__main__":


@ -0,0 +1,96 @@
import argparse
from pathlib import Path
import torch
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from diffusers import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
class Args(argparse.Namespace):
source_path: str
output_path: str | None
half: bool
verbose: bool
def setup_converter(args: Args) -> ModelConverter:
source: nn.Module = UNet2DConditionModel.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path, subfolder="unet"
)
source_in_channels: int = source.config.in_channels # type: ignore
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
target = (
SDXLUNet(in_channels=source_in_channels)
if source_has_time_ids
else SD1UNet(in_channels=source_in_channels, clip_embedding_dim=source_clip_embedding_dim)
)
x = torch.randn(1, source_in_channels, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, source_clip_embedding_dim)
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
added_cond_kwargs = {}
if source_has_time_ids:
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target_args = (x,)
source_args = {
"positional": (x, timestep, clip_text_embeddings),
"keyword": {"added_cond_kwargs": added_cond_kwargs} if source_has_time_ids else {},
}
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(
source_args=source_args,
target_args=target_args,
):
raise RuntimeError("Model conversion failed")
return converter
def main() -> None:
parser = argparse.ArgumentParser(
description="Converts a Diffusion UNet model to a Refiners SD1UNet or SDXLUNet model"
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="runwayml/stable-diffusion-v1-5",
help=(
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
" runwayml/stable-diffusion-v1-5"
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True")
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-unet.safetensors"
converter = setup_converter(args=args)
converter.save_to_safetensors(path=args.output_path, half=args.half)
if __name__ == "__main__":
main()
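Since the script inspects `addition_embed_type` to pick between `SD1UNet` and `SDXLUNet`, the same entry point should also cover SDXL checkpoints; an illustrative run (repository ID borrowed from the removed SDXL scripts below, output name illustrative):
```bash
python scripts/conversion/convert_diffusers_unet.py \
  --from stabilityai/stable-diffusion-xl-base-0.9 \
  --to sdxl-unet.safetensors
```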


@ -0,0 +1,64 @@
import argparse
from typing import TYPE_CHECKING, cast
import torch
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
try:
from model import Generator # type: ignore
except ImportError:
raise ImportError(
"Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
" PYTHONPATH"
)
if TYPE_CHECKING:
Generator = cast(nn.Module, Generator)
class Args(argparse.Namespace):
source_path: str
output_path: str
verbose: bool
half: bool
def setup_converter(args: Args) -> ModelConverter:
source = Generator(3, 1, 3)
source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu")) # type: ignore
source.eval()
target = InformativeDrawings()
x = torch.randn(1, 3, 512, 512)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(x,)):
raise RuntimeError("Model conversion failed")
return converter
def main() -> None:
parser = argparse.ArgumentParser(
description="Converts a pretrained Informative Drawings model to a refiners Informative Drawings model"
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="model2.pth",
help="Path to the source model. (default: 'model2.pth').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default="informative-drawings.safetensors",
help="Path to save the converted model. (default: 'informative-drawings.safetensors').",
)
parser.add_argument("--verbose", action="store_true", dest="verbose")
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args(namespace=Args())
converter = setup_converter(args=args)
converter.save_to_safetensors(path=args.output_path, half=args.half)
if __name__ == "__main__":
main()
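Because the script imports `Generator` from `model.py`, the informative-drawings repository must be importable when it runs; a sketch of one way to invoke it, assuming the entry point is `scripts/conversion/convert_informative_drawings.py` (not shown in this diff) and that `model.py` sits at the root of the cloned repository:
```bash
PYTHONPATH=/path/to/informative-drawings python scripts/conversion/convert_informative_drawings.py \
  --from model2.pth --to informative-drawings.safetensors --half
```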


@ -1,49 +1,36 @@
import argparse
from functools import partial
from torch import Tensor
from refiners.fluxion.utils import (
load_from_safetensors,
load_metadata_from_safetensors,
save_to_safetensors,
)
from convert_diffusers_unet import setup_converter as convert_unet, Args as UnetConversionArgs
from convert_transformers_clip_text_model import (
setup_converter as convert_text_encoder,
Args as TextEncoderConversionArgs,
)
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget
from refiners.fluxion.layers.module import Module
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import create_state_dict_mapping
import torch
from diffusers import DiffusionPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
@torch.no_grad()
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: SD1UNet) -> dict[str, str] | None:
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 768)
src_args = (x, timestep, clip_text_embeddings)
dst_model.set_timestep(timestep=timestep)
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
dst_args = (x,)
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore
def get_unet_mapping(source_path: str) -> dict[str, str]:
args = UnetConversionArgs(source_path=source_path, verbose=False)
return convert_unet(args=args).get_mapping()
@torch.no_grad()
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer("Nice cat")
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
def get_text_encoder_mapping(source_path: str) -> dict[str, str]:
args = TextEncoderConversionArgs(source_path=source_path, subfolder="text_encoder", verbose=False)
return convert_text_encoder(
args=args,
).get_mapping()
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description="Converts a refiner's LoRA weights to SD-WebUI's LoRA weights")
parser.add_argument(
"-i",
"--input-file",
@ -55,7 +42,7 @@ def main() -> None:
"-o",
"--output-file",
type=str,
required=True,
default="sdwebui_loras.safetensors",
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
)
parser.add_argument(
@ -66,27 +53,22 @@ def main() -> None:
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
)
args = parser.parse_args()
metadata = load_metadata_from_safetensors(path=args.input_file)
assert metadata is not None
assert metadata is not None, f"Could not load metadata from {args.input_file}"
tensors = load_from_safetensors(path=args.input_file)
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.sd15) # type: ignore
state_dict: dict[str, torch.Tensor] = {}
state_dict: dict[str, Tensor] = {}
for meta_key, meta_value in metadata.items():
match meta_key:
case "unet_targets":
src_model = diffusers_sd.unet # type: ignore
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
create_mapping = create_unet_mapping
model = SD1UNet(in_channels=4, clip_embedding_dim=768)
create_mapping = partial(get_unet_mapping, source_path=args.sd15)
key_prefix = "unet."
lora_prefix = "lora_unet_"
case "text_encoder_targets":
src_model = diffusers_sd.text_encoder # type: ignore
dst_model = CLIPTextEncoderL()
create_mapping = create_text_encoder_mapping
model = CLIPTextEncoderL()
create_mapping = partial(get_text_encoder_mapping, source_path=args.sd15)
key_prefix = "text_encoder."
lora_prefix = "lora_te_"
case "lda_targets":
@ -94,8 +76,8 @@ def main() -> None:
case _:
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
submodule_to_key: dict[Module, str] = {}
for name, submodule in dst_model.named_modules():
submodule_to_key: dict[fl.Module, str] = {}
for name, submodule in model.named_modules():
submodule_to_key[submodule] = name
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
@ -110,14 +92,14 @@ def main() -> None:
#
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
refiners_to_diffusers = create_mapping(src_model, dst_model) # type: ignore
refiners_to_diffusers = create_mapping()
assert refiners_to_diffusers is not None
# Compute the corresponding diffusers' keys where LoRA layers must be applied
lora_injection_points: list[str] = [
refiners_to_diffusers[submodule_to_key[linear]]
for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
for layer in dst_model.layers(layer_type=target.get_class())
for layer in model.layers(layer_type=target.get_class())
for linear in layer.layers(layer_type=fl.Linear)
]
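The `create_mapping = partial(get_unet_mapping, source_path=args.sd15)` pattern above binds the source path immediately and defers the (expensive) model tracing until `create_mapping()` is actually called for the matching metadata key; a minimal, self-contained sketch of that idiom with illustrative names:
```python
from functools import partial


def get_mapping(source_path: str) -> dict[str, str]:
    # Stand-in for get_unet_mapping / get_text_encoder_mapping: the real helpers
    # trace both models and return a refiners -> diffusers key mapping.
    return {"Chain.Linear": f"{source_path}:down_blocks.0.proj"}


create_mapping = partial(get_mapping, source_path="runwayml/stable-diffusion-v1-5")
# Nothing is traced yet; the bound source_path is only used here:
print(create_mapping())
```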


@ -0,0 +1,104 @@
import argparse
from pathlib import Path
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from transformers import CLIPTextModelWithProjection # type: ignore
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
import refiners.fluxion.layers as fl
class Args(argparse.Namespace):
source_path: str
subfolder: str
output_path: str | None
use_half: bool
verbose: bool
def setup_converter(args: Args) -> ModelConverter:
source: nn.Module = CLIPTextModelWithProjection.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder
)
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
architecture: str = source.config.architectures[0] # type: ignore
embedding_dim: int = source.config.hidden_size # type: ignore
projection_dim: int = source.config.projection_dim # type: ignore
num_layers: int = source.config.num_hidden_layers # type: ignore
num_attention_heads: int = source.config.num_attention_heads # type: ignore
feed_forward_dim: int = source.config.intermediate_size # type: ignore
use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore
target = CLIPTextEncoder(
embedding_dim=embedding_dim,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
feedforward_dim=feed_forward_dim,
use_quick_gelu=use_quick_gelu,
)
match architecture:
case "CLIPTextModel":
source.text_projection = fl.Identity()
case "CLIPTextModelWithProjection":
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
case _:
raise RuntimeError(f"Unsupported architecture: {architecture}")
text = "What a nice cat you have there!"
tokenizer = target.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer(text)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(tokens,), target_args=(text,)):
raise RuntimeError("Model conversion failed")
return converter
def main() -> None:
parser = argparse.ArgumentParser(
description="Converts a CLIPTextEncoder from the library transformers from the HuggingFace Hub to refiners."
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="runwayml/stable-diffusion-v1-5",
help=(
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
" runwayml/stable-diffusion-v1-5"
),
)
parser.add_argument(
"--subfolder",
type=str,
dest="subfolder",
default="text_encoder",
help=(
"Subfolder in the source path where the model is located inside the Hub. Default: text_encoder (for"
" CLIPTextModel)"
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True")
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
converter = setup_converter(args=args)
converter.save_to_safetensors(path=args.output_path, half=args.half)
if __name__ == "__main__":
main()
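Since the script branches on the source `architectures` field, the same entry point should also handle projection-equipped encoders such as SDXL's second text encoder; an illustrative run (repository ID and output name borrowed from the removed CLIPTextEncoderG script, subfolder name assumed):
```bash
python scripts/conversion/convert_transformers_clip_text_model.py \
  --from stabilityai/stable-diffusion-xl-base-0.9 \
  --subfolder text_encoder_2 \
  --to CLIPTextEncoderG.safetensors
```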


@ -1,52 +0,0 @@
import torch
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
from diffusers import DiffusionPipeline # type: ignore
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@torch.no_grad()
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
dst_model = CLIPTextEncoderL()
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer("Nice cat")
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
assert mapping is not None, "Model conversion failed"
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v.half() for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="CLIPTextEncoderL.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder # type: ignore
tensors = convert(src_model=src_model) # type: ignore
save_to_safetensors(path=args.output_file, tensors=tensors)
if __name__ == "__main__":
main()


@ -1,57 +0,0 @@
# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings
# Code is at https://github.com/carolineec/informative-drawings
# Copy `model.py` into your `PYTHONPATH`. You can edit it to remove unnecessary code
# and imports if you want; we only need `Generator`.
import torch
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
from model import Generator # type: ignore
@torch.no_grad()
def convert(checkpoint: str, device: torch.device | str) -> dict[str, torch.Tensor]:
src_model = Generator(3, 1, 3) # type: ignore
src_model.load_state_dict(torch.load(checkpoint, map_location=device)) # type: ignore
src_model.eval() # type: ignore
dst_model = InformativeDrawings()
x = torch.randn(1, 3, 512, 512)
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
assert mapping is not None, "Model conversion failed"
state_dict = convert_state_dict(source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping) # type: ignore
return {k: v.half() for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="model2.pth",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="informative-drawings.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
tensors = convert(checkpoint=args.source, device=device)
save_to_safetensors(path=args.output_file, tensors=tensors)
if __name__ == "__main__":
main()


@ -1,53 +0,0 @@
import torch
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
save_to_safetensors,
)
from diffusers import DiffusionPipeline # type: ignore
from diffusers.models.autoencoder_kl import AutoencoderKL # type: ignore
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@torch.no_grad()
def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]:
dst_model = LatentDiffusionAutoencoder()
x = torch.randn(1, 3, 512, 512)
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
assert mapping is not None, "Model conversion failed"
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v.half() for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="lda.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).vae # type: ignore
tensors = convert(src_model=src_model) # type: ignore
save_to_safetensors(path=args.output_file, tensors=tensors)
if __name__ == "__main__":
main()


@ -1,58 +0,0 @@
import torch
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
from diffusers import StableDiffusionInpaintPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = SD1UNet(in_channels=9, clip_embedding_dim=768)
x = torch.randn(1, 9, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 768)
src_args = (x, timestep, clip_text_embeddings)
dst_model.set_timestep(timestep=timestep)
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
dst_args = (x,)
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore
assert mapping is not None, "Model conversion failed"
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v.half() for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-inpainting",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="stable_diffusion_1_5_inpainting_unet.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
tensors = convert(src_model=src_model) # type: ignore
save_to_safetensors(path=args.output_file, tensors=tensors)
if __name__ == "__main__":
main()


@ -1,58 +0,0 @@
import torch
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
from diffusers import DiffusionPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 768)
src_args = (x, timestep, clip_text_embeddings)
dst_model.set_timestep(timestep=timestep)
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
dst_args = (x,)
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore
assert mapping is not None, "Model conversion failed"
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v.half() for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="stable_diffusion_1_5_unet.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
tensors = convert(src_model=src_model) # type: ignore
save_to_safetensors(path=args.output_file, tensors=tensors)
if __name__ == "__main__":
main()


@ -1,57 +0,0 @@
import torch
from safetensors.torch import save_file # type: ignore
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
from diffusers import DiffusionPipeline # type: ignore
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
import refiners.fluxion.layers as fl
@torch.no_grad()
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
dst_model = CLIPTextEncoderG()
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer("Nice cat")
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
if mapping is None:
raise RuntimeError("Could not create state dict mapping")
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v.half() for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="stabilityai/stable-diffusion-xl-base-0.9",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="CLIPTextEncoderG.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2 # type: ignore
tensors = convert(src_model=src_model) # type: ignore
save_file(tensors=tensors, filename=args.output_file)
if __name__ == "__main__":
main()


@ -1,65 +0,0 @@
import torch
from safetensors.torch import save_file # type: ignore
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
from diffusers import DiffusionPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = SDXLUNet(in_channels=4)
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 2048)
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
src_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs)
dst_model.set_timestep(timestep=timestep)
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
dst_model.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
dst_model.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
dst_args = (x,)
mapping = create_state_dict_mapping(
source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args # type: ignore
)
if mapping is None:
raise RuntimeError("Could not create state dict mapping")
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="stabilityai/stable-diffusion-xl-base-0.9",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="stable_diffusion_xl_unet.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
tensors = convert(src_model=src_model) # type: ignore
save_file(tensors=tensors, filename=args.output_file)
if __name__ == "__main__":
main()


@ -0,0 +1,545 @@
from collections import defaultdict
from enum import Enum, auto
from pathlib import Path
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle
import torch
from typing import Any, DefaultDict, TypedDict
from refiners.fluxion.utils import norm, save_to_safetensors
TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Linear,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm,
nn.GroupNorm,
nn.Embedding,
nn.MaxPool2d,
nn.AvgPool2d,
nn.AdaptiveAvgPool2d,
]
ModelTypeShape = tuple[str, tuple[torch.Size, ...]]
class ModuleArgsDict(TypedDict):
"""Represents positional and keyword arguments passed to a module.
- `positional`: A tuple of positional arguments.
- `keyword`: A dictionary of keyword arguments.
"""
positional: tuple[Any, ...]
keyword: dict[str, Any]
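# Illustrative examples: wherever a `ModuleArgs` is expected (e.g. by `ModelConverter.run`),
# all three forms below are accepted and normalized by `_unpack_module_args`:
#   (x, timestep, clip_text_embedding)                        # tuple of positional arguments
#   {"timestep": timestep}                                    # dict of keyword arguments
#   {"positional": (x,), "keyword": {"timestep": timestep}}   # ModuleArgsDict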
class ConversionStage(Enum):
"""Represents the current stage of the conversion process.
- `INIT`: The conversion process has not started.
- `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers.
"""
INIT = auto()
BASIC_LAYERS_MATCH = auto()
SHAPE_AND_LAYERS_MATCH = auto()
MODELS_OUTPUT_AGREE = auto()
class ModelConverter:
ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
stage: ConversionStage = ConversionStage.INIT
_stored_mapping: dict[str, str] | None = None
def __init__(
self,
source_model: nn.Module,
target_model: nn.Module,
source_keys_to_skip: list[str] | None = None,
target_keys_to_skip: list[str] | None = None,
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
threshold: float = 1e-5,
skip_output_check: bool = False,
verbose: bool = True,
) -> None:
"""
Create a ModelConverter.
- `source_model`: The model to convert from.
- `target_model`: The model to convert to.
- `source_keys_to_skip`: A list of keys to skip when tracing the source model.
- `target_keys_to_skip`: A list of keys to skip when tracing the target model.
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
- `threshold`: The threshold for comparing outputs between the source and target models.
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
- `verbose`: Whether to print messages during the conversion process.
The conversion process consists of four stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
### Example:
```
converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
is_converted = converter(args)
if is_converted:
converter.save_to_safetensors(path="test.pt")
```
"""
self.source_model = source_model
self.target_model = target_model
self.source_keys_to_skip = source_keys_to_skip or []
self.target_keys_to_skip = target_keys_to_skip or []
self.custom_layer_mapping = custom_layer_mapping or {}
self.threshold = threshold
self.skip_output_check = skip_output_check
self.verbose = verbose
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
"""
Run the conversion process.
- `source_args`: The arguments to pass to the source model; they can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
- `target_args`: The arguments to pass to the target model; they can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns:
- `True` if the conversion process is done and the models agree.
The conversion process consists of four stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
"""
if target_args is None:
target_args = source_args
match self.stage:
case ConversionStage.MODELS_OUTPUT_AGREE:
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
return True
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_models_output_agree_stage():
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
return True
case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
source_args=source_args, target_args=target_args
):
self.stage = (
ConversionStage.SHAPE_AND_LAYERS_MATCH
if not self.skip_output_check
else ConversionStage.MODELS_OUTPUT_AGREE
)
return self.run(source_args=source_args, target_args=target_args)
case ConversionStage.INIT if self._run_init_stage():
self.stage = ConversionStage.BASIC_LAYERS_MATCH
return self.run(source_args=source_args, target_args=target_args)
case _:
return False
def __repr__(self) -> str:
return (
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
)
def __bool__(self) -> bool:
return self.stage == ConversionStage.MODELS_OUTPUT_AGREE
def get_state_dict(self) -> dict[str, Tensor]:
"""Get the converted state_dict."""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
return self.target_model.state_dict()
def get_mapping(self) -> dict[str, str]:
"""Get the mapping between the source and target models' state_dicts."""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
assert self._stored_mapping is not None, "Mapping is not stored"
return self._stored_mapping
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
"""Save the converted model to a SafeTensors file.
This method can only be called after the conversion process is done.
- `path`: The path to save the converted model to.
- `metadata`: Metadata to save with the converted model.
- `half`: Whether to save the converted model as half precision.
### Raises:
- `ValueError` if the conversion process is not done yet. Run `converter(args)` first.
"""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
state_dict = self.get_state_dict()
if half:
state_dict = {key: value.half() for key, value in state_dict.items()}
save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)
def map_state_dicts(
self,
source_args: ModuleArgs,
target_args: ModuleArgs | None = None,
) -> dict[str, str] | None:
"""
Find a mapping between the source and target models' state_dicts.
- `source_args`: The arguments to pass to the source model; they can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
- `target_args`: The arguments to pass to the target model; they can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns:
- A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
"""
if target_args is None:
target_args = source_args
source_order = self._trace_module_execution_order(
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
)
target_order = self._trace_module_execution_order(
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
)
if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
return None
mapping: dict[str, str] = {}
for model_type_shape in source_order:
source_keys = source_order[model_type_shape]
target_keys = target_order[model_type_shape]
mapping.update(zip(target_keys, source_keys))
return mapping
def compare_models(
self,
source_args: ModuleArgs,
target_args: ModuleArgs | None = None,
threshold: float = 1e-5,
) -> bool:
"""
Compare the outputs of the source and target models.
- `source_args`: The arguments to pass to the source model; they can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
- `target_args`: The arguments to pass to the target model; they can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `threshold`: The threshold for comparing outputs between the source and target models.
"""
if target_args is None:
target_args = source_args
source_outputs = self._collect_layers_outputs(
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
)
target_outputs = self._collect_layers_outputs(
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
)
prev_source_key, prev_target_key = None, None
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
diff = norm(source_output - target_output).item()
if diff > threshold:
self._log(
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
f" {target_key}, difference in norm: {diff}"
)
return False
prev_source_key, prev_target_key = source_key, target_key
return True
def _run_init_stage(self) -> bool:
"""Run the init stage of the conversion process."""
is_count_correct = self._verify_basic_layers_count()
is_not_missing_layers = self._verify_missing_basic_layers()
return is_count_correct and is_not_missing_layers
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the basic layers match stage of the conversion process."""
self._log(message="Finding matching shapes and layers...")
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
self._stored_mapping = mapping
if mapping is None:
self._log(message="Models do not have matching shapes.")
return False
self._log(message="Found matching shapes and layers. Converting state_dict...")
source_state_dict = self.source_model.state_dict()
target_state_dict = self.target_model.state_dict()
converted_state_dict = self._convert_state_dict(
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
self.target_model.load_state_dict(state_dict=converted_state_dict)
return True
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the shape and layers match stage of the conversion process."""
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
return True
else:
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
return False
def _run_models_output_agree_stage(self) -> bool:
"""Run the models output agree stage of the conversion process."""
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
return True
def _log(self, message: str) -> None:
"""Print a message if `verbose` is `True`."""
if self.verbose:
print(message)
def _debug_print_shapes(
self,
shape: ModelTypeShape,
source_keys: list[str],
target_keys: list[str],
) -> None:
"""Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
self._log(message=f"{shape}")
max_len = max(len(source_keys), len(target_keys))
for i in range(max_len):
source_key = source_keys[i] if i < len(source_keys) else "---"
target_key = target_keys[i] if i < len(target_keys) else "---"
self._log(f"\t{source_key}\t{target_key}")
@staticmethod
def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""Unpack the positional and keyword arguments passed to a module."""
match module_args:
case tuple(positional_args):
keyword_args: dict[str, Any] = {}
case {"positional": positional_args, "keyword": keyword_args}:
pass
case _:
positional_args = ()
keyword_args = dict(**module_args)
return positional_args, keyword_args
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
"""Infer the type of a basic layer."""
for layer_type in TORCH_BASIC_LAYERS:
if isinstance(module, layer_type):
return layer_type
for source_type in self.custom_layer_mapping.keys():
if isinstance(module, source_type):
return source_type
return None
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
"""Get the signature of a module."""
layer_type = self._infer_basic_layer_type(module=module)
assert layer_type is not None, f"Module {module} is not a basic layer"
param_shapes = [p.shape for p in module.parameters()]
return (str(object=layer_type), tuple(param_shapes))
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
"""Count the number of basic layers in a module."""
count: DefaultDict[type[nn.Module], int] = defaultdict(int)
for submodule in module.modules():
layer_type = self._infer_basic_layer_type(module=submodule)
if layer_type is not None:
count[layer_type] += 1
return count
def _verify_basic_layers_count(self) -> bool:
"""Verify that the source and target models have the same number of basic layers."""
source_layers = self._count_basic_layers(module=self.source_model)
target_layers = self._count_basic_layers(module=self.target_model)
diff: dict[type[nn.Module], tuple[int, int]] = {}
for layer_type in set(source_layers.keys()) | set(target_layers.keys()):
source_count = source_layers.get(layer_type, 0)
target_count = target_layers.get(layer_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
if diff:
message = "Models do not have the same number of basic layers:\n"
for layer_type, counts in diff.items():
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
self._log(message=message.rstrip())
return False
return True
def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
"""Check if a module is a leaf module with weights."""
return next(module.parameters(), None) is not None and next(module.children(), None) is None
def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
"""Check if a module has weighted leaf modules that are not basic layers."""
return [
type(submodule)
for submodule in module.modules()
if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
]
def _verify_missing_basic_layers(self) -> bool:
"""Verify that the source and target models do not have missing basic layers."""
missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)
if missing_source_layers or missing_target_layers:
self._log(
message=(
"Models might have missing basic layers. You can either pass them into keys to skip or set"
f" `check_missing_basic_layer` to `False`: {missing_source_layers}, {missing_target_layers}"
)
)
return False
return True
@torch.no_grad()
def _trace_module_execution_order(
self,
module: nn.Module,
args: ModuleArgs,
keys_to_skip: list[str],
) -> dict[ModelTypeShape, list[str]]:
"""
Execute a forward pass and store the order of execution of specific sub-modules.
- `module`: The module to trace.
- `args`: The arguments to pass to the module; they can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `keys_to_skip`: A list of keys to skip when tracing the module.
### Returns:
- A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
"""
submodule_to_key: dict[nn.Module, str] = {}
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
layer_signature = self.get_module_signature(module=layer)
execution_order[layer_signature].append(submodule_to_key[layer])
hooks: list[RemovableHandle] = []
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
for name, submodule in named_modules:
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
submodule_to_key[submodule] = name # type: ignore
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
positional_args, keyword_args = self._unpack_module_args(module_args=args)
module(*positional_args, **keyword_args)
for hook in hooks:
hook.remove()
return dict(execution_order)
def _assert_shapes_aligned(
self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
) -> bool:
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
        shape_mismatched = False
        for model_type_shape in model_type_shapes:
            source_keys = source_order.get(model_type_shape, [])
            target_keys = target_order.get(model_type_shape, [])
            if len(source_keys) != len(target_keys):
                shape_mismatched = True
                self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
        return not shape_mismatched
@staticmethod
def _convert_state_dict(
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
) -> dict[str, Tensor]:
"""Convert the source model's state_dict to match the target model's state_dict."""
converted_state_dict: dict[str, Tensor] = {}
for target_key in target_state_dict:
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
source_prefix = state_dict_mapping[target_prefix]
source_key = ".".join([source_prefix, suffix])
converted_state_dict[target_key] = source_state_dict[source_key]
return converted_state_dict
@torch.no_grad()
def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]:
"""
Execute a forward pass and store the output of specific sub-modules.
- `module`: The module to trace.
        - `args`: The arguments to pass to the module; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `keys_to_skip`: A list of keys to skip when tracing the module.
### Returns:
- A list of tuples containing the key of each sub-module and its output.
### Note:
        - The output of each sub-module is cloned so that later in-place modifications do not alter the stored values.
"""
submodule_to_key: dict[nn.Module, str] = {}
execution_order: list[tuple[str, Tensor]] = []
def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None:
execution_order.append((submodule_to_key[layer], output.clone()))
hooks: list[RemovableHandle] = []
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
for name, submodule in named_modules:
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
submodule_to_key[submodule] = name # type: ignore
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
positional_args, keyword_args = self._unpack_module_args(module_args=args)
module(*positional_args, **keyword_args)
for hook in hooks:
hook.remove()
return execution_order
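
As a reading aid (not part of this commit), here is a minimal sketch of the three argument forms accepted as `ModuleArgs` by `_unpack_module_args` above, together with the converter entry point as it is used elsewhere in this diff. The toy models and tensor shapes are hypothetical, and it assumes `ModelConverter` accepts plain torch modules on both sides:

```python
import torch
from torch import nn
from refiners.fluxion.model_converter import ModelConverter

# `ModuleArgs` can be a tuple of positional arguments, a dict of keyword
# arguments, or a dict with explicit "positional" and "keyword" entries.
x = torch.randn(2, 8)
args_positional = (x,)
args_keyword = {"input": x}
args_mixed = {"positional": (x,), "keyword": {}}

# Hypothetical toy models with the same basic layers in the same order.
source = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
target = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=False)
if not converter.run(source_args=args_positional, target_args=args_positional):
    raise RuntimeError("Model conversion failed")
```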

View file

@@ -1,18 +1,13 @@
from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
from typing import Dict, Iterable, Literal, TypeVar
from PIL import Image
from numpy import array, float32
from pathlib import Path
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
import torch
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn
from torch.utils.hooks import RemovableHandle
if TYPE_CHECKING:
from refiners.fluxion.layers.module import Module
from torch import Tensor, device as Device, dtype as DType
T = TypeVar("T")
@@ -31,7 +26,7 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0) -> Tensor:
return _pad(input=x, pad=pad, value=value) # type: ignore
def interpolate(x: Tensor, factor: float | Size, mode: str = "nearest") -> Tensor:
def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor:
return (
_interpolate(x, scale_factor=factor, mode=mode)
if isinstance(factor, float | int)
@@ -44,7 +39,9 @@ def bidirectional_mapping(mapping: Dict[str, str]) -> Dict[str, str]:
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
return tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(0)
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
0
)
def tensor_to_image(tensor: Tensor) -> Image.Image:
@@ -77,220 +74,3 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
_save_file(tensors, path, metadata) # type: ignore
class BasicLayers(Enum):
Conv1d = nn.Conv1d
Conv2d = nn.Conv2d
Conv3d = nn.Conv3d
ConvTranspose1d = nn.ConvTranspose1d
ConvTranspose2d = nn.ConvTranspose2d
ConvTranspose3d = nn.ConvTranspose3d
Linear = nn.Linear
BatchNorm1d = nn.BatchNorm1d
BatchNorm2d = nn.BatchNorm2d
BatchNorm3d = nn.BatchNorm3d
LayerNorm = nn.LayerNorm
GroupNorm = nn.GroupNorm
Embedding = nn.Embedding
MaxPool2d = nn.MaxPool2d
AvgPool2d = nn.AvgPool2d
AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d
ModelTypeShape = tuple[str, tuple[Size, ...]]
def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
"""Identify if the provided module matches any in the BasicLayers enum."""
for layer_type in BasicLayers:
if isinstance(module, layer_type.value):
return layer_type
return None
def get_module_signature(module: "Module") -> ModelTypeShape:
"""Return a tuple representing the module's type and parameter shapes."""
layer_type = infer_basic_layer_type(module=module)
assert layer_type is not None, f"Module {module} is not a basic layer"
param_shapes = [p.shape for p in module.parameters()]
return (str(object=layer_type), tuple(param_shapes))
def forward_order_of_execution(
module: "Module",
example_args: tuple[Any, ...],
key_skipper: Callable[[str], bool] | None = None,
) -> dict[ModelTypeShape, list[str]]:
"""
Determine the execution order of sub-modules during a forward pass.
Optionally skips specific modules using `key_skipper`.
"""
key_skipper = key_skipper or (lambda _: False)
submodule_to_key: dict["Module", str] = {}
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
layer_signature = get_module_signature(module=layer)
execution_order[layer_signature].append(submodule_to_key[layer])
hooks: list[RemovableHandle] = []
for name, submodule in module.named_modules():
if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
submodule_to_key[submodule] = name
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
with no_grad():
module(*example_args)
for hook in hooks:
hook.remove()
return dict(execution_order)
def print_side_by_side(
shape: ModelTypeShape,
source_keys: list[str],
target_keys: list[str],
) -> None:
"""Print module keys side by side, useful for debugging shape mismatches."""
print(f"{shape}")
max_len = max(len(source_keys), len(target_keys))
for i in range(max_len):
source_key = source_keys[i] if i < len(source_keys) else "---"
target_key = target_keys[i] if i < len(target_keys) else "---"
print(f"\t{source_key}\t{target_key}")
def verify_shape_match(
source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
) -> bool:
"""Check if the sub-modules in source and target have matching shapes."""
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
shape_missmatched = False
for model_type_shape in model_type_shapes:
source_keys = source_order.get(model_type_shape, [])
target_keys = target_order.get(model_type_shape, [])
if len(source_keys) != len(target_keys):
shape_missmatched = True
print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
return not shape_missmatched
def create_state_dict_mapping(
source_model: "Module",
target_model: "Module",
source_args: tuple[Any, ...],
target_args: tuple[Any, ...] | None = None,
source_key_skipper: Callable[[str], bool] | None = None,
target_key_skipper: Callable[[str], bool] | None = None,
) -> dict[str, str] | None:
"""
Create a mapping between state_dict keys of the source and target models.
This facilitates the transfer of weights when architectures have slight differences.
"""
if target_args is None:
target_args = source_args
source_order = forward_order_of_execution(
module=source_model, example_args=source_args, key_skipper=source_key_skipper
)
target_order = forward_order_of_execution(
module=target_model, example_args=target_args, key_skipper=target_key_skipper
)
if not verify_shape_match(source_order=source_order, target_order=target_order):
return None
mapping: dict[str, str] = {}
for model_type_shape in source_order:
source_keys = source_order[model_type_shape]
target_keys = target_order[model_type_shape]
mapping.update(zip(target_keys, source_keys))
return mapping
def convert_state_dict(
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
) -> dict[str, Tensor]:
"""Convert source state_dict based on the provided mapping to match target state_dict structure."""
converted_state_dict: dict[str, Tensor] = {}
for target_key in target_state_dict:
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
source_prefix = state_dict_mapping[target_prefix]
source_key = ".".join([source_prefix, suffix])
converted_state_dict[target_key] = source_state_dict[source_key]
return converted_state_dict
def forward_store_outputs(
module: "Module",
example_args: tuple[Any, ...],
key_skipper: Callable[[str], bool] | None = None,
) -> list[tuple[str, Tensor]]:
"""Execute a forward pass and store outputs of specific sub-modules."""
key_skipper = key_skipper or (lambda _: False)
submodule_to_key: dict["Module", str] = {}
execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None:
execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
hooks: list[RemovableHandle] = []
for name, submodule in module.named_modules():
if (infer_basic_layer_type(module=module) is not None) and not key_skipper(name):
submodule_to_key[submodule] = name
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
with no_grad():
module(*example_args)
for hook in hooks:
hook.remove()
return execution_order
def compare_models(
source_model: "Module",
target_model: "Module",
source_args: tuple[Any, ...],
target_args: tuple[Any, ...] | None = None,
source_key_skipper: Callable[[str], bool] | None = None,
target_key_skipper: Callable[[str], bool] | None = None,
threshold: float = 1e-5,
) -> bool:
"""
Compare the outputs of two models given the same inputs.
Flag if any difference exceeds the given threshold.
"""
if target_args is None:
target_args = source_args
source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper)
target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper)
prev_source_key, prev_target_key = None, None
for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
diff = norm(source_output - target_output).item()
if diff > threshold:
print(
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
f" {target_key}, difference in norm: {diff}"
)
return False
prev_source_key, prev_target_key = source_key, target_key
return True
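
The helpers deleted here have direct counterparts in the new `ModelConverter` introduced by this commit: `infer_basic_layer_type` and `get_module_signature` become methods, `forward_order_of_execution` becomes `_trace_module_execution_order`, `verify_shape_match` becomes `_assert_shapes_aligned`, `print_side_by_side` becomes `_debug_print_shapes`, `convert_state_dict` becomes `_convert_state_dict`, `forward_store_outputs` becomes `_collect_layers_outputs`, and the `create_state_dict_mapping` / `compare_models` workflow is folded into `run`. A rough migration sketch follows; the toy models are hypothetical, and it assumes `run` chains the same steps end to end:

```python
import torch
from torch import nn
from refiners.fluxion.model_converter import ModelConverter

source, target = nn.Linear(4, 4), nn.Linear(4, 4)  # hypothetical stand-ins for real models
x = torch.randn(1, 4)

# Before this commit, the removed helpers were chained by hand, roughly:
#   mapping = create_state_dict_mapping(source_model=source, target_model=target, source_args=(x,))
#   target.load_state_dict(convert_state_dict(source.state_dict(), target.state_dict(), mapping))
#   assert compare_models(source_model=source, target_model=target, source_args=(x,))

# After this commit, the converter wraps those steps behind a single call:
converter = ModelConverter(source_model=source, target_model=target, verbose=False)
succeeded = converter.run(source_args=(x,))  # False if any verification or comparison step fails
```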

View file

@@ -3,9 +3,10 @@ from pathlib import Path
from warnings import warn
import pytest
import torch
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.fluxion.utils import compare_models
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
from refiners.fluxion.model_converter import ModelConverter
@pytest.fixture(scope="module")
@@ -47,23 +48,28 @@ def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet:
@torch.no_grad()
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
torch.manual_seed(seed=0) # type: ignore
source = diffusers_sdxl_unet
target = refiners_sdxl_unet
manual_seed(seed=0)
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 2048)
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
source_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs)
refiners_sdxl_unet.set_timestep(timestep=timestep)
refiners_sdxl_unet.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
refiners_sdxl_unet.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
refiners_sdxl_unet.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target_args = (x,)
source_args = {
"positional": (x, timestep, clip_text_embeddings),
"keyword": {"added_cond_kwargs": added_cond_kwargs},
}
assert compare_models(
source_model=diffusers_sdxl_unet,
target_model=refiners_sdxl_unet,
converter = ModelConverter(source_model=source, target_model=target, verbose=False, threshold=1e-2)
assert converter.run(
source_args=source_args,
target_args=target_args,
threshold=1e-2,
)
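
One detail worth noting when porting tests like this one: the tolerance moves from a `compare_models` argument to the `ModelConverter` constructor (`threshold=1e-2` above), and a layer is flagged as diverging when the norm of its output difference exceeds that threshold, as in the removed `compare_models` helper which the converter's output check presumably mirrors. A small worked example of the criterion, with hypothetical tensors shaped like the CLIP text embeddings used here:

```python
import torch

# Divergence is measured as the norm of the difference over the whole output,
# not per element, so large tensors cross a given threshold quickly.
source_output = torch.ones(1, 77, 2048)
target_output = source_output + 1e-4  # every element off by 1e-4

diff = torch.norm(source_output - target_output).item()
print(diff)          # ~0.0397, i.e. sqrt(77 * 2048) * 1e-4
print(diff > 1e-2)   # True: this layer would be reported as diverging at threshold=1e-2
```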