delete old conversion scripts

Laurent 2024-10-09 09:24:59 +00:00 committed by Laureηt
parent 2796117d2d
commit 189cfa1a69
18 changed files with 0 additions and 3747 deletions

View file

@@ -1 +0,0 @@
::: refiners.fluxion.model_converter

View file

@@ -1,81 +0,0 @@
import argparse
from pathlib import Path
import torch
from diffusers import AutoencoderKL # type: ignore
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
class Args(argparse.Namespace):
    source_path: str
    output_path: str | None
    subfolder: str
    use_half: bool
    verbose: bool
def setup_converter(args: Args) -> ModelConverter:
target = LatentDiffusionAutoencoder()
    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
source: nn.Module = AutoencoderKL.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path,
subfolder=args.subfolder,
low_cpu_mem_usage=False,
) # type: ignore
x = torch.randn(1, 3, 512, 512)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(x,)):
raise RuntimeError("Model conversion failed")
return converter
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert a pretrained diffusers AutoencoderKL model to a refiners Latent Diffusion Autoencoder"
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="runwayml/stable-diffusion-v1-5",
help="Path to the source pretrained model (default: 'runwayml/stable-diffusion-v1-5').",
)
parser.add_argument(
"--subfolder",
type=str,
dest="subfolder",
default="vae",
help="Subfolder in the source path where the model is located inside the Hub (default: 'vae')",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
" be the source path with the extension changed to .safetensors."
),
)
parser.add_argument(
"--half",
action="store_true",
dest="use_half",
default=False,
help="Use this flag to save the output file as half precision (default: full precision).",
)
parser.add_argument(
"--verbose",
action="store_true",
dest="verbose",
default=False,
help="Use this flag to print verbose output during conversion.",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-autoencoder.safetensors"
assert args.output_path is not None
converter = setup_converter(args=args)
converter.save_to_safetensors(path=args.output_path, half=args.use_half)
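
For context on the script above: the .safetensors file it writes is meant to be loaded back into the refiners LatentDiffusionAutoencoder. A minimal sketch of that round trip, assuming the script's default output name (derived from its default source path, runwayml/stable-diffusion-v1-5):

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

# "stable-diffusion-v1-5-autoencoder.safetensors" is the default output name produced by
# the script above (source path stem + "-autoencoder.safetensors"); adjust as needed.
autoencoder = LatentDiffusionAutoencoder()
autoencoder.load_state_dict(load_from_safetensors("stable-diffusion-v1-5-autoencoder.safetensors"))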

View file

@@ -1,233 +0,0 @@
# pyright: reportPrivateUsage=false
import argparse
from pathlib import Path
import torch
from diffusers import ControlNetModel # type: ignore
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import (
DPMSolver,
SD1ControlnetAdapter,
SD1UNet,
)
class Args(argparse.Namespace):
source_path: str
output_path: str | None
@no_grad()
def convert(args: Args) -> dict[str, torch.Tensor]:
    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path,
low_cpu_mem_usage=False,
)
unet = SD1UNet(in_channels=4)
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
controlnet = adapter.controlnet
condition = torch.randn(1, 3, 512, 512)
adapter.set_controlnet_condition(condition=condition)
clip_text_embedding = torch.rand(1, 77, 768)
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
solver = DPMSolver(num_inference_steps=10)
timestep = solver.timesteps[0].unsqueeze(dim=0)
unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
x = torch.randn(1, 4, 64, 64)
    # We need the hack below because our implementation is not strictly equivalent
    # to diffusers in terms of execution order: we compute the residuals inline
    # instead of in a separate step.
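    # Editor's note (clarification, not part of the original script): in diffusers, the
    # `controlnet_down_blocks.*` 1x1 convolutions all run after the attention blocks and the
    # residuals are gathered in a separate step, whereas in refiners each Passthrough Conv2d
    # runs right after its block. The per-shape reordering below rewrites the traced source
    # execution order accordingly so that `_assert_shapes_aligned` can match the two models.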
converter = ModelConverter(
source_model=controlnet_src, target_model=controlnet, skip_output_check=True, verbose=False
)
source_order = converter._trace_module_execution_order(
module=controlnet_src, args=(x, timestep, clip_text_embedding, condition), keys_to_skip=[]
)
target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[])
broken_k = (nn.Conv2d, (torch.Size([320, 320, 1, 1]), torch.Size([320])))
expected_source_order = [
"down_blocks.0.attentions.0.proj_in",
"down_blocks.0.attentions.0.proj_out",
"down_blocks.0.attentions.1.proj_in",
"down_blocks.0.attentions.1.proj_out",
"controlnet_down_blocks.0",
"controlnet_down_blocks.1",
"controlnet_down_blocks.2",
"controlnet_down_blocks.3",
]
expected_target_order = [
"DownBlocks.Chain_1.Passthrough.Conv2d",
"DownBlocks.Chain_2.CLIPLCrossAttention.Chain_1.Conv2d",
"DownBlocks.Chain_2.CLIPLCrossAttention.Chain_3.Conv2d",
"DownBlocks.Chain_2.Passthrough.Conv2d",
"DownBlocks.Chain_3.CLIPLCrossAttention.Chain_1.Conv2d",
"DownBlocks.Chain_3.CLIPLCrossAttention.Chain_3.Conv2d",
"DownBlocks.Chain_3.Passthrough.Conv2d",
"DownBlocks.Chain_4.Passthrough.Conv2d",
]
fixed_source_order = [
"controlnet_down_blocks.0",
"down_blocks.0.attentions.0.proj_in",
"down_blocks.0.attentions.0.proj_out",
"controlnet_down_blocks.1",
"down_blocks.0.attentions.1.proj_in",
"down_blocks.0.attentions.1.proj_out",
"controlnet_down_blocks.2",
"controlnet_down_blocks.3",
]
assert source_order[broken_k] == expected_source_order
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = (nn.Conv2d, (torch.Size([640, 640, 1, 1]), torch.Size([640])))
expected_source_order = [
"down_blocks.1.attentions.0.proj_in",
"down_blocks.1.attentions.0.proj_out",
"down_blocks.1.attentions.1.proj_in",
"down_blocks.1.attentions.1.proj_out",
"controlnet_down_blocks.4",
"controlnet_down_blocks.5",
"controlnet_down_blocks.6",
]
expected_target_order = [
"DownBlocks.Chain_5.CLIPLCrossAttention.Chain_1.Conv2d",
"DownBlocks.Chain_5.CLIPLCrossAttention.Chain_3.Conv2d",
"DownBlocks.Chain_5.Passthrough.Conv2d",
"DownBlocks.Chain_6.CLIPLCrossAttention.Chain_1.Conv2d",
"DownBlocks.Chain_6.CLIPLCrossAttention.Chain_3.Conv2d",
"DownBlocks.Chain_6.Passthrough.Conv2d",
"DownBlocks.Chain_7.Passthrough.Conv2d",
]
fixed_source_order = [
"down_blocks.1.attentions.0.proj_in",
"down_blocks.1.attentions.0.proj_out",
"controlnet_down_blocks.4",
"down_blocks.1.attentions.1.proj_in",
"down_blocks.1.attentions.1.proj_out",
"controlnet_down_blocks.5",
"controlnet_down_blocks.6",
]
assert source_order[broken_k] == expected_source_order
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = (nn.Conv2d, (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
expected_source_order = [
"down_blocks.2.attentions.0.proj_in",
"down_blocks.2.attentions.0.proj_out",
"down_blocks.2.attentions.1.proj_in",
"down_blocks.2.attentions.1.proj_out",
"mid_block.attentions.0.proj_in",
"mid_block.attentions.0.proj_out",
"controlnet_down_blocks.7",
"controlnet_down_blocks.8",
"controlnet_down_blocks.9",
"controlnet_down_blocks.10",
"controlnet_down_blocks.11",
"controlnet_mid_block",
]
expected_target_order = [
"DownBlocks.Chain_8.CLIPLCrossAttention.Chain_1.Conv2d",
"DownBlocks.Chain_8.CLIPLCrossAttention.Chain_3.Conv2d",
"DownBlocks.Chain_8.Passthrough.Conv2d",
"DownBlocks.Chain_9.CLIPLCrossAttention.Chain_1.Conv2d",
"DownBlocks.Chain_9.CLIPLCrossAttention.Chain_3.Conv2d",
"DownBlocks.Chain_9.Passthrough.Conv2d",
"DownBlocks.Chain_10.Passthrough.Conv2d",
"DownBlocks.Chain_11.Passthrough.Conv2d",
"DownBlocks.Chain_12.Passthrough.Conv2d",
"MiddleBlock.CLIPLCrossAttention.Chain_1.Conv2d",
"MiddleBlock.CLIPLCrossAttention.Chain_3.Conv2d",
"MiddleBlock.Passthrough.Conv2d",
]
fixed_source_order = [
"down_blocks.2.attentions.0.proj_in",
"down_blocks.2.attentions.0.proj_out",
"controlnet_down_blocks.7",
"down_blocks.2.attentions.1.proj_in",
"down_blocks.2.attentions.1.proj_out",
"controlnet_down_blocks.8",
"controlnet_down_blocks.9",
"controlnet_down_blocks.10",
"controlnet_down_blocks.11",
"mid_block.attentions.0.proj_in",
"mid_block.attentions.0.proj_out",
"controlnet_mid_block",
]
assert source_order[broken_k] == expected_source_order
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
assert converter._assert_shapes_aligned(source_order=source_order, target_order=target_order), "Shapes not aligned"
mapping: dict[str, str] = {}
for model_type_shape in source_order:
source_keys = source_order[model_type_shape]
target_keys = target_order[model_type_shape]
mapping.update(zip(target_keys, source_keys))
state_dict = converter._convert_state_dict(
source_state_dict=controlnet_src.state_dict(),
target_state_dict=controlnet.state_dict(),
state_dict_mapping=mapping,
)
return {k: v.half() for k, v in state_dict.items()}
def main() -> None:
parser = argparse.ArgumentParser(description="Convert a diffusers ControlNet model to a Refiners ControlNet model")
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="lllyasviel/sd-controlnet-depth",
help=(
"Can be a path to a .bin, a .safetensors file, or a model identifier from Hugging Face Hub. Defaults to"
" lllyasviel/sd-controlnet-depth"
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
required=False,
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-controlnet.safetensors"
state_dict = convert(args=args)
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
main()

View file

@@ -1,154 +0,0 @@
import argparse
from pathlib import Path
import torch
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
# Running:
#
# from diffusers import UNet2DConditionModel
# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# for k in unet.attn_processors.keys():
# print(k)
#
# Gives:
#
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
# ...
# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
# ...
# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
# mid_block.attentions.0.transformer_blocks.0.attn2.processor
#
# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
# indices:
#
# DownBlocks -> [1, 3, 5, 7, 9, 11]
# MiddleBlock -> [31]
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
#
# Same for SDXL with more layers (70 cross-attentions vs. 16)
CROSS_ATTN_MAPPING: dict[str, list[int]] = {
"sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
"sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
}
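# Editor's note (illustrative sanity check, not part of the original script): the index
# arithmetic in the comment above can be verified directly. SD 1.5 has 16 cross-attention
# layers and SDXL has 70, i.e. half the number of checkpoint keys matched in main() below
# (32 and 140: one `to_k_ip` plus one `to_v_ip` weight per layer).
assert len(CROSS_ATTN_MAPPING["sd15"]) == 16
assert len(CROSS_ATTN_MAPPING["sdxl"]) == 70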
def main() -> None:
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--verbose", action="store_true", dest="verbose")
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
# Do not use `load_tensors`: first-level values are not tensors.
weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu") # type: ignore
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
image_proj_weights = weights["image_proj"]
ip_adapter_weights = weights["ip_adapter"]
fine_grained = "latents" in image_proj_weights # aka IP-Adapter plus
match len(ip_adapter_weights):
case 32:
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
case 140:
ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4), fine_grained=fine_grained)
cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
case _:
raise ValueError("Unexpected number of keys in input checkpoint")
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
# [1]: https://github.com/tencent-ailab/IP-Adapter
state_dict: dict[str, torch.Tensor] = {}
image_proj_state_dict: dict[str, torch.Tensor]
if fine_grained:
w = image_proj_weights
image_proj_state_dict = {
"LatentsToken.Parameter.weight": w["latents"].squeeze(0), # drop batch dim = 1
"Linear_1.weight": w["proj_in.weight"],
"Linear_1.bias": w["proj_in.bias"],
"Linear_2.weight": w["proj_out.weight"],
"Linear_2.bias": w["proj_out.bias"],
"LayerNorm.weight": w["norm_out.weight"],
"LayerNorm.bias": w["norm_out.bias"],
}
for i in range(4):
t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}."
image_proj_state_dict.update(
{
f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"],
f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"],
f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"],
f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"],
f"{t_pfx}1.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"],
f"{t_pfx}1.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"],
f"{t_pfx}1.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"],
f"{t_pfx}2.LayerNorm.weight": w[f"{s_pfx}1.0.weight"],
f"{t_pfx}2.LayerNorm.bias": w[f"{s_pfx}1.0.bias"],
f"{t_pfx}2.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"],
f"{t_pfx}2.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"],
}
)
else:
image_proj_state_dict = {
"Linear.weight": image_proj_weights["proj.weight"],
"Linear.bias": image_proj_weights["proj.bias"],
"LayerNorm.weight": image_proj_weights["norm.weight"],
"LayerNorm.bias": image_proj_weights["norm.bias"],
}
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
for k, v in image_proj_state_dict.items():
state_dict[f"image_proj.{k}"] = v
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
for i, _ in enumerate(ip_adapter.sub_adapters):
cross_attn_index = cross_attn_mapping[i]
k_ip = f"{cross_attn_index}.to_k_ip.weight"
v_ip = f"{cross_attn_index}.to_v_ip.weight"
# the name of the key is not checked at runtime, so we keep the original name
state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]
if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
main()

View file

@@ -1,63 +0,0 @@
import argparse
from pathlib import Path
import torch
from diffusers import T2IAdapter # type: ignore
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, ConditionEncoderXL
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert a pretrained diffusers T2I-Adapter model to refiners")
parser.add_argument(
"--from",
type=str,
dest="source_path",
required=True,
help="Path or repository name of the source model. (e.g.: 'ip-adapter_sd15.bin').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
" be the source path with the extension changed to .safetensors."
),
)
parser.add_argument(
"--half",
action="store_true",
dest="use_half",
default=False,
help="Use this flag to save the output file as half precision (default: full precision).",
)
parser.add_argument(
"--verbose",
action="store_true",
dest="verbose",
default=False,
help="Use this flag to print verbose output during conversion.",
)
args = parser.parse_args()
if args.output_path is None:
args.output_path = f"{Path(args.source_path).name}.safetensors"
assert args.output_path is not None
sdxl = "xl" in args.source_path
target = ConditionEncoderXL() if sdxl else ConditionEncoder()
    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
source: nn.Module = T2IAdapter.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path,
low_cpu_mem_usage=False,
)
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
x = torch.randn(1, 3, 1024, 1024) if sdxl else torch.randn(1, 3, 512, 512)
converter = ModelConverter(source_model=source, target_model=target, verbose=args.verbose)
if not converter.run(source_args=(x,)):
raise RuntimeError("Model conversion failed")
converter.save_to_safetensors(path=args.output_path, half=args.use_half)

View file

@@ -1,148 +0,0 @@
import argparse
from pathlib import Path
from typing import Any
import torch
from diffusers import UNet2DConditionModel # type: ignore
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_from_safetensors, load_tensors
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
class Args(argparse.Namespace):
source_path: str
output_path: str | None
subfolder: str
half: bool
verbose: bool
skip_init_check: bool
override_weights: str | None
def setup_converter(args: Args) -> ModelConverter:
    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
source: nn.Module = UNet2DConditionModel.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path,
subfolder=args.subfolder,
low_cpu_mem_usage=False,
)
if args.override_weights is not None:
if args.override_weights.endswith(".pth"):
sd = load_tensors(args.override_weights)
elif args.override_weights.endswith(".safetensors"):
sd = load_from_safetensors(args.override_weights)
else:
raise ValueError(f"Unsupported file format: {args.override_weights}")
source.load_state_dict(sd)
source_in_channels: int = source.config.in_channels # type: ignore
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
source_is_lcm: bool = source.config.time_cond_proj_dim is not None
if source_has_time_ids:
target = SDXLUNet(in_channels=source_in_channels)
else:
target = SD1UNet(in_channels=source_in_channels)
if source_is_lcm:
assert isinstance(target, SDXLUNet)
SDXLLcmAdapter(target=target).inject()
x = torch.randn(1, source_in_channels, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, source_clip_embedding_dim)
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
added_cond_kwargs = {}
if isinstance(target, SDXLUNet):
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target_args = (x,)
source_kwargs: dict[str, Any] = {}
if source_has_time_ids:
source_kwargs["added_cond_kwargs"] = added_cond_kwargs
if source_is_lcm:
source_kwargs["timestep_cond"] = torch.randn(1, source.config.time_cond_proj_dim)
source_args = {
"positional": (x, timestep, clip_text_embeddings),
"keyword": source_kwargs,
}
converter = ModelConverter(
source_model=source,
target_model=target,
skip_init_check=args.skip_init_check,
skip_output_check=True,
verbose=args.verbose,
)
if not converter.run(
source_args=source_args,
target_args=target_args,
):
raise RuntimeError("Model conversion failed")
return converter
def main() -> None:
parser = argparse.ArgumentParser(
description="Converts a Diffusion UNet model to a Refiners SD1UNet or SDXLUNet model"
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="runwayml/stable-diffusion-v1-5",
help=(
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
" runwayml/stable-diffusion-v1-5"
),
)
parser.add_argument(
"--override-weights",
type=str,
default=None,
help=(
"Path to a weights file to override the source model (keeping its config). "
"This is useful for models distributed as .pth files."
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
parser.add_argument("--subfolder", type=str, default="unet", help="Subfolder. Default: unet.")
parser.add_argument(
"--skip-init-check",
action="store_true",
help="Skip check that source and target have the same layers count.",
)
parser.add_argument("--half", action="store_true", help="Convert to half precision.")
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-unet.safetensors"
converter = setup_converter(args=args)
converter.save_to_safetensors(path=args.output_path, half=args.half)
if __name__ == "__main__":
main()

View file

@@ -1,176 +0,0 @@
import argparse
from pathlib import Path
import torch
from refiners.fluxion.utils import load_tensors, save_to_safetensors
def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
"""Convert a DINOv2 weights from facebook to refiners."""
# get depth from "blocks" keys
depth = max([int(k.split(".")[1]) for k in weights.keys() if k.startswith("blocks.")]) + 1
# only needed when pre-training
del weights["mask_token"]
# squeeze cls_token and position_embeddings
weights["cls_token"] = weights["cls_token"].squeeze(0)
weights["pos_embed"] = weights["pos_embed"].squeeze(0)
# rename "w12" to "fc1" and "w3" to "fc2", only for giant model
for key in list(weights.keys()):
if "w3" in key:
new_key = key.replace("w3", "fc2")
weights[new_key] = weights.pop(key)
elif "w12" in key:
# we swap w1 and w2 because of the difference between our GLU implementation and theirs
# see https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L31-L34
# and https://github.com/finegrain-ai/refiners/blob/a2ee70578361e4d84a65a8708564480a9b0ec67e/src/refiners/fluxion/layers/activations.py#L158-L160
weight = weights.pop(key)
w1, w2 = weight.chunk(2, dim=0)
w21 = torch.cat([w2, w1], dim=0)
new_key = key.replace("w12", "fc1")
weights[new_key] = w21
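    # Editor's note (illustrative, not part of the original script): the swap above simply
    # exchanges the two halves of the fused "w12" weight along dim 0. For example, with
    # torch.arange(4).reshape(4, 1), chunk(2, dim=0) gives [[0], [1]] and [[2], [3]], and
    # torch.cat([w2, w1], dim=0) yields [[2], [3], [0], [1]].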
rename_keys: list[tuple[str, str]] = [
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
("norm.weight", "LayerNorm.weight"),
("norm.bias", "LayerNorm.bias"),
]
for i in range(depth):
rename_keys.append(
(
f"blocks.{i}.norm1.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.norm1.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.attn.proj.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.attn.proj.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.ls1.gamma",
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerScale.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.norm2.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.norm2.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc1.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc1.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc2.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc2.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.ls2.gamma",
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerScale.weight",
),
)
if "register_tokens" in weights:
weights["register_tokens"] = weights["register_tokens"].squeeze(0)
rename_keys.append(("register_tokens", "Registers.Parameter.weight"))
# rename keys
for old_key, new_key in rename_keys:
weights[new_key] = weights.pop(old_key)
# split the qkv weights and biases
for i in range(depth):
qkv_weight = weights.pop(f"blocks.{i}.attn.qkv.weight")
q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.weight"] = q_weight
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.weight"] = k_weight
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.weight"] = v_weight
qkv_bias = weights.pop(f"blocks.{i}.attn.qkv.bias")
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.bias"] = q_bias
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.bias"] = k_bias
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.bias"] = v_bias
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
help=(
"Official checkpoint from https://github.com/facebookresearch/dinov2"
" e.g. /path/to/dinov2_vits14_pretrain.pth"
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
weights = load_tensors(args.source_path)
convert_dinov2_facebook(weights)
if args.half:
weights = {key: value.half() for key, value in weights.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=weights)
if __name__ == "__main__":
main()

View file

@@ -1,102 +0,0 @@
import argparse
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download # type: ignore
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
class Args(argparse.Namespace):
source_path: str
output_path: str | None
use_half: bool
def convert(args: Args) -> dict[str, torch.Tensor]:
if Path(args.source_path).suffix != ".safetensors":
args.source_path = hf_hub_download(
repo_id=args.source_path, filename="ella-sd1.5-tsc-t5xl.safetensors", local_dir="tests/weights/ELLA-Adapter"
)
weights = load_from_safetensors(args.source_path)
for key in list(weights.keys()):
if "latents" in key:
new_key = "PerceiverResampler.Latents.ParameterInitialized.weight"
weights[new_key] = weights.pop(key)
elif "time_embedding" in key:
new_key = key.replace("time_embedding", "TimestepEncoder.RangeEncoder").replace("linear", "Linear")
weights[new_key] = weights.pop(key)
elif "proj_in" in key:
new_key = f"PerceiverResampler.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "time_aware" in key:
new_key = f"PerceiverResampler.Residual.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "attn.in_proj" in key:
layer_num = int(key.split(".")[2])
query_param, key_param, value_param = weights.pop(key).chunk(3, dim=0)
param_type = "weight" if "weight" in key else "bias"
for i, param in enumerate([query_param, key_param, value_param]):
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Distribute.Linear_{i+1}.{param_type}"
weights[new_key] = param
elif "attn.out_proj" in key:
layer_num = int(key.split(".")[2])
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "ln_ff" in key:
layer_num = int(key.split(".")[2])
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.AdaLayerNorm.Parallel.Chain.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "ln_1" in key or "ln_2" in key:
layer_num = int(key.split(".")[2])
n = 1 if int(key.split(".")[3].split("_")[-1]) == 2 else 2
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Distribute.AdaLayerNorm_{n}.Parallel.Chain.Linear.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
elif "mlp" in key:
layer_num = int(key.split(".")[2])
n = 1 if "c_fc" in key else 2
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.FeedForward.Linear_{n}.{key.split('.')[-1]}"
weights[new_key] = weights.pop(key)
if args.use_half:
weights = {key: value.half() for key, value in weights.items()}
return weights
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert a pretrained Ella Adapter to refiners implementation")
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="QQGYLab/ELLA",
        help=(
            "A path to local .safetensors weights. If not provided, a repo from the Hugging Face Hub will be"
            " used (defaults to QQGYLab/ELLA)."
        ),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
" be the source path with the prefix set to refiners"
),
)
parser.add_argument(
"--half",
action="store_true",
dest="use_half",
        default=False,
help="Use this flag to save the output file as half precision (default: full precision).",
)
args = parser.parse_args(namespace=Args())
weights = convert(args)
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
save_to_safetensors(path=args.output_path, tensors=weights)

View file

@@ -1,348 +0,0 @@
import argparse
import logging
from logging import info
from pathlib import Path
from huggingface_hub import hf_hub_download # type: ignore
from torch import Tensor
from torch.nn import Parameter as TorchParameter
from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras
from refiners.fluxion.layers import Conv2d
from refiners.fluxion.layers.linear import Linear
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import (
ConditionEncoder,
ControlLora,
ControlLoraAdapter,
ZeroConvolution,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
def sort_keys(key: str, /) -> tuple[str, int]:
"""Compute the score of a key, relatively to its suffix.
When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level".
Args:
key: The key to sort.
Returns:
The padded suffix of the key.
The score of the key's suffix.
"""
if "time_embed" in key: # HACK: will place the "time_embed" layers at very start of the list
return ("", -2)
if "label_emb" in key: # HACK: will place the "label_emb" layers right after "time_embed"
return ("", -1)
if "proj_out" in key: # HACK: will place the "proj_out" layers at the end of each "transformer_blocks"
return (key.removesuffix("proj_out") + "transformer_blocks.99.ff.net.2", 10)
return SDLoraManager.sort_keys(key)
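# Editor's note (illustrative, not part of the original script): the three HACK branches
# above pin the ordering of special keys. With some hypothetical keys:
#   sort_keys("time_embed.0.weight")  -> ("", -2)   # sorts first
#   sort_keys("label_emb.0.0.weight") -> ("", -1)   # sorts right after
#   sort_keys("output_blocks.3.1.proj_out")
#       -> ("output_blocks.3.1.transformer_blocks.99.ff.net.2", 10)  # last within its block
# Every other key falls through to SDLoraManager.sort_keys.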
def load_lora_layers(
name: str,
state_dict: dict[str, Tensor],
control_lora: ControlLora,
) -> dict[str, Lora[Linear | Conv2d]]:
"""Load the LoRA layers from the state_dict into the ControlLora.
Args:
name: The name of the LoRA.
state_dict: The state_dict of the LoRA.
control_lora: The ControlLora to load the LoRA layers into.
"""
# filter from the state_dict the layers that will be used for the LoRA layers
lora_weights = {f"{key}.weight": value for key, value in state_dict.items() if ".up" in key or ".down" in key}
# move the tensors to the device and dtype of the ControlLora
lora_weights = {
key: value.to(
dtype=control_lora.dtype,
device=control_lora.device,
)
for key, value in lora_weights.items()
}
# load every LoRA layers from the filtered state_dict
lora_layers = Lora.from_dict(name, state_dict=lora_weights)
# sort all the LoRA's keys using the `sort_keys` method
lora_layers = {
key: lora_layers[key]
for key in sorted(
lora_layers.keys(),
key=sort_keys,
)
}
# auto-attach the LoRA layers to the U-Net
auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])
# eject all the LoRA adapters from the U-Net
# because we need each target path as if the adapter wasn't injected
for lora_layer in lora_layers.values():
lora_adapter = lora_layer.parent
assert isinstance(lora_adapter, LoraAdapter)
lora_adapter.eject()
return lora_layers
def load_condition_encoder(
state_dict: dict[str, Tensor],
control_lora: ControlLora,
) -> None:
"""Load the ConditionEncoder's Conv2d layers from the state_dict into the ControlLora.
Args:
state_dict: The state_dict of the ConditionEncoder.
control_lora: The control_lora to load the ConditionEncoder's Conv2d layers into.
"""
# filter from the state_dict the layers that will be used for the ConditionEncoder
condition_encoder_tensors = {key: value for key, value in state_dict.items() if "input_hint_block" in key}
# move the tensors to the device and dtype of the ControlLora
condition_encoder_tensors = {
key: value.to(
dtype=control_lora.dtype,
device=control_lora.device,
)
for key, value in condition_encoder_tensors.items()
}
# find the ConditionEncoder's Conv2d layers
condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
condition_encoder_conv2ds = list(condition_encoder_layer.layers(Conv2d))
# replace the Conv2d layers' weights and biases with the ones from the state_dict
for i, layer in enumerate(condition_encoder_conv2ds):
layer.weight = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.weight"])
layer.bias = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.bias"])
def load_zero_convolutions(
state_dict: dict[str, Tensor],
control_lora: ControlLora,
) -> None:
"""Load the ZeroConvolution's Conv2d layers from the state_dict into the ControlLora.
Args:
state_dict: The state_dict of the ZeroConvolution.
control_lora: The ControlLora to load the ZeroConvolution's Conv2d layers into.
"""
# filter from the state_dict the layers that will be used for the ZeroConvolution layers
zero_convolution_tensors = {key: value for key, value in state_dict.items() if "zero_convs" in key}
n = len(zero_convolution_tensors) // 2
zero_convolution_tensors[f"zero_convs.{n}.0.weight"] = state_dict["middle_block_out.0.weight"]
zero_convolution_tensors[f"zero_convs.{n}.0.bias"] = state_dict["middle_block_out.0.bias"]
# move the tensors to the device and dtype of the ControlLora
zero_convolution_tensors = {
key: value.to(
dtype=control_lora.dtype,
device=control_lora.device,
)
for key, value in zero_convolution_tensors.items()
}
# find the ZeroConvolution's Conv2d layers
zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
zero_convolution_conv2ds = [layer.ensure_find(Conv2d) for layer in zero_convolution_layers]
# replace the Conv2d layers' weights and biases with the ones from the state_dict
for i, layer in enumerate(zero_convolution_conv2ds):
layer.weight = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.weight"])
layer.bias = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.bias"])
def simplify_key(key: str, prefix: str, index: int | None = None) -> str:
"""Simplify a key by stripping everything to the left of the prefix.
Also optionally add a zero-padded index to the prefix.
Example:
>>> simplify_key("foo.bar.ControlLora.something", "ControlLora", 1)
"ControlLora_01.something"
>>> simplify_key("foo.bar.ControlLora.DownBlocks.something", "ControlLora")
"ControlLora.DownBlocks.something"
Args:
key: The key to simplify.
prefix: The prefix to remove.
index: The index to add.
"""
_, right = key.split(prefix, maxsplit=1)
if index:
return f"{prefix}_{index:02d}{right}"
else:
return f"{prefix}{right}"
def convert_lora_layers(
lora_layers: dict[str, Lora[Linear | Conv2d]],
control_lora: ControlLora,
refiners_state_dict: dict[str, Tensor],
) -> None:
"""Convert the LoRA layers to the refiners format.
Args:
lora_layers: The LoRA layers to convert.
control_lora: The ControlLora to convert the LoRA layers from.
refiners_state_dict: The refiners state dict to update with the converted LoRA layers.
"""
for lora_layer in lora_layers.values():
# get the adapter associated with the LoRA layer
lora_adapter = lora_layer.parent
assert isinstance(lora_adapter, LoraAdapter)
# get the path of the adapter's target in the ControlLora
target = lora_adapter.target
path = target.get_path(parent=control_lora.ensure_find_parent(target))
state_dict = {
f"{path}.down": lora_layer.down.weight,
f"{path}.up": lora_layer.up.weight,
}
state_dict = {simplify_key(key, "ControlLora."): param for key, param in state_dict.items()}
refiners_state_dict.update(state_dict)
def convert_zero_convolutions(
control_lora: ControlLora,
refiners_state_dict: dict[str, Tensor],
) -> None:
"""Convert the ZeroConvolution layers to the refiners format.
Args:
control_lora: The ControlLora to convert the ZeroConvolution layers from.
refiners_state_dict: The refiners state dict to update with the converted ZeroConvolution layers.
"""
zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
for i, zero_convolution_layer in enumerate(zero_convolution_layers):
state_dict = zero_convolution_layer.state_dict()
path = zero_convolution_layer.get_path()
state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
state_dict = {simplify_key(key, "ZeroConvolution", i + 1): param for key, param in state_dict.items()}
refiners_state_dict.update(state_dict)
def convert_condition_encoder(
control_lora: ControlLora,
refiners_state_dict: dict[str, Tensor],
) -> None:
"""Convert the ConditionEncoder to the refiners format.
Args:
control_lora: The ControlLora to convert the ConditionEncoder from.
refiners_state_dict: The refiners state dict to update with the converted ConditionEncoder.
"""
condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
path = condition_encoder_layer.get_path()
state_dict = condition_encoder_layer.state_dict()
state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
state_dict = {simplify_key(key, "ConditionEncoder"): param for key, param in state_dict.items()}
refiners_state_dict.update(state_dict)
def convert(
name: str,
state_dict_path: Path,
output_path: Path,
) -> None:
sdxl = StableDiffusion_XL()
info("Stable Diffusion XL model initialized")
fooocus_state_dict = load_from_safetensors(state_dict_path)
info(f"Fooocus weights loaded from: {state_dict_path}")
control_lora_adapter = ControlLoraAdapter(target=sdxl.unet, name=name).inject()
control_lora = control_lora_adapter.control_lora
info("ControlLoraAdapter initialized")
lora_layers = load_lora_layers(name, fooocus_state_dict, control_lora)
info("LoRA layers loaded")
load_zero_convolutions(fooocus_state_dict, control_lora)
info("ZeroConvolution layers loaded")
load_condition_encoder(fooocus_state_dict, control_lora)
info("ConditionEncoder loaded")
refiners_state_dict: dict[str, Tensor] = {}
convert_lora_layers(lora_layers, control_lora, refiners_state_dict)
info("LoRA layers converted to refiners format")
convert_zero_convolutions(control_lora, refiners_state_dict)
info("ZeroConvolution layers converted to refiners format")
convert_condition_encoder(control_lora, refiners_state_dict)
info("ConditionEncoder converted to refiners format")
output_path.parent.mkdir(parents=True, exist_ok=True)
save_to_safetensors(path=output_path, tensors=refiners_state_dict)
info(f"Converted ControlLora state dict saved to disk at: {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert ControlLora (from Fooocus) weights to refiners.",
)
parser.add_argument(
"--from",
type=Path,
dest="source_path",
default="lllyasviel/misc:control-lora-canny-rank128.safetensors",
help="Path to the state_dict of the ControlLora, or a Hugging Face model ID.",
)
parser.add_argument(
"--to",
type=Path,
dest="output_path",
help=(
"Path to save the converted model (extension will be .safetensors)."
"If not specified, the output path will be the source path with the extension changed to .safetensors."
),
)
parser.add_argument(
"--verbose",
action="store_true",
dest="verbose",
default=False,
help="Use this flag to print verbose output during conversion.",
)
args = parser.parse_args()
if args.verbose:
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(message)s",
)
if not args.source_path.exists():
repo_id, filename = str(args.source_path).split(":")
args.source_path = Path(
hf_hub_download(
repo_id=repo_id,
filename=filename,
)
)
if args.output_path is None:
args.output_path = Path(f"refiners_{args.source_path.stem}.safetensors")
convert(
name=args.source_path.stem,
state_dict_path=args.source_path,
output_path=args.output_path,
)

View file

@@ -1,81 +0,0 @@
import argparse
from torch import Tensor
from refiners.fluxion.utils import load_tensors, save_to_safetensors
def main() -> None:
parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format")
parser.add_argument(
"--from",
type=str,
dest="source_path",
required=True,
default="sam_hq_vit_h.pth",
help="Path to the source model checkpoint.",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
required=True,
default="refiners_sam_hq_vit_h.safetensors",
help="Path to save the converted model in Refiners format.",
)
args = parser.parse_args()
source_state_dict = load_tensors(args.source_path)
state_dict: dict[str, Tensor] = {}
for suffix in ["weight", "bias"]:
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.0.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.0.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.Conv2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.0.{suffix}"
]
state_dict[f"HQFeatures.CompressViTFeat.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.1.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.1.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.1.{suffix}"
]
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.3.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.3.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.Conv2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.3.{suffix}"
]
state_dict = {f"Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ.{k}": v for k, v in state_dict.items()}
# HQ Token
state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"]
# HQ MLP
for i in range(3):
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.weight"] = source_state_dict[
f"mask_decoder.hf_mlp.layers.{i}.weight"
]
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.bias"] = source_state_dict[
f"mask_decoder.hf_mlp.layers.{i}.bias"
]
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
main()

View file

@@ -1,89 +0,0 @@
import argparse
from pathlib import Path
from convert_diffusers_unet import Args as UNetArgs, setup_converter as setup_unet_converter
from huggingface_hub import hf_hub_download # type: ignore
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
class Args(argparse.Namespace):
source_path: str
output_path: str | None
subfolder: str
half: bool
verbose: bool
reference_unet_path: str
def main() -> None:
parser = argparse.ArgumentParser(description="Converts IC-Light patch weights to work with Refiners")
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="lllyasviel/ic-light",
help=(
"Can be a path to a .bin file, a .safetensors file or a model name from the Hugging Face Hub. Default:"
" lllyasviel/ic-light"
),
)
parser.add_argument("--filename", type=str, default="iclight_sd15_fc.safetensors", help="Filename inside the hub.")
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
parser.add_argument(
"--reference-unet-path",
type=str,
dest="reference_unet_path",
default="runwayml/stable-diffusion-v1-5",
help="Path to the reference UNet weights.",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.filename).stem}-refiners.safetensors"
patch_file = (
Path(args.source_path)
if args.source_path.endswith(".safetensors")
else Path(
hf_hub_download(
repo_id=args.source_path,
filename=args.filename,
)
)
)
patch_weights = load_from_safetensors(patch_file)
unet_args = UNetArgs(
source_path=args.reference_unet_path,
subfolder="unet",
half=False,
verbose=False,
skip_init_check=True,
override_weights=None,
)
converter = setup_unet_converter(args=unet_args)
result = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
source_state_dict=patch_weights,
target_state_dict=converter.target_model.state_dict(),
state_dict_mapping=converter.get_mapping(),
)
save_to_safetensors(path=args.output_path, tensors=result)
if __name__ == "__main__":
main()

View file

@@ -1,65 +0,0 @@
import argparse
from typing import cast
import torch
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
try:
from model import Generator # type: ignore
except ImportError:
raise ImportError(
"Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
" PYTHONPATH"
)
class Args(argparse.Namespace):
source_path: str
output_path: str
verbose: bool
half: bool
def setup_converter(args: Args) -> ModelConverter:
source = cast(nn.Module, Generator(3, 1, 3))
source.load_state_dict(state_dict=load_tensors(args.source_path))
source.eval()
target = InformativeDrawings()
x = torch.randn(1, 3, 512, 512)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(x,)):
raise RuntimeError("Model conversion failed")
return converter
def main() -> None:
parser = argparse.ArgumentParser(
description="Converts a pretrained Informative Drawings model to a refiners Informative Drawings model"
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="model2.pth",
help="Path to the source model. (default: 'model2.pth').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default="informative-drawings.safetensors",
help="Path to save the converted model. (default: 'informative-drawings.safetensors').",
)
parser.add_argument("--verbose", action="store_true", dest="verbose")
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args(namespace=Args())
converter = setup_converter(args=args)
converter.save_to_safetensors(path=args.output_path, half=args.half)
if __name__ == "__main__":
main()

View file

@@ -1,40 +0,0 @@
import argparse
from pathlib import Path
from refiners.fluxion.utils import load_tensors, save_to_safetensors
from refiners.foundationals.swin.mvanet.converter import convert_weights
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
help="A MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
src_weights = load_tensors(args.source_path)
weights = convert_weights(src_weights)
if args.half:
weights = {key: value.half() for key, value in weights.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=weights)
if __name__ == "__main__":
main()

View file

@@ -1,268 +0,0 @@
import argparse
import types
from typing import Any, Callable, cast
import torch
import torch.nn as nn
from segment_anything import build_sam_vit_h # type: ignore
from segment_anything.modeling.common import LayerNorm2d # type: ignore
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
class FacebookSAM(nn.Module):
image_encoder: nn.Module
prompt_encoder: nn.Module
mask_decoder: nn.Module
build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h)
assert issubclass(LayerNorm2d, nn.Module)
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
class Args(argparse.Namespace):
source_path: str
output_path: str
half: bool
verbose: bool
def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
refiners_mask_encoder = MaskEncoder()
converter = ModelConverter(
source_model=prompt_encoder.mask_downscaling,
target_model=refiners_mask_encoder,
custom_layer_mapping=custom_layers, # type: ignore
)
x = torch.randn(1, 256, 256)
mapping = converter.map_state_dicts(source_args=(x,))
assert mapping
source_state_dict = prompt_encoder.mask_downscaling.state_dict()
target_state_dict = refiners_mask_encoder.state_dict()
# Mapping handled manually (see below) because nn.Parameter is a special case
del target_state_dict["no_mask_embedding"]
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
state_dict: dict[str, Tensor] = {
"no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore
}
state_dict.update(converted_source)
refiners_mask_encoder.load_state_dict(state_dict=state_dict)
return state_dict
def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [
prompt_encoder.not_a_point_embed.weight
] # type: ignore
pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore
assert isinstance(pe, Tensor)
state_dict: dict[str, Tensor] = {
"Residual.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)),
"CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()),
}
refiners_prompt_encoder = PointEncoder()
refiners_prompt_encoder.load_state_dict(state_dict=state_dict)
return state_dict
def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
refiners_sam_vit_h = SAMViTH()
converter = ModelConverter(
source_model=vit,
target_model=refiners_sam_vit_h,
custom_layer_mapping=custom_layers, # type: ignore
)
converter.skip_init_check = True
x = torch.randn(1, 3, 1024, 1024)
mapping = converter.map_state_dicts(source_args=(x,))
assert mapping
mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"
target_state_dict = refiners_sam_vit_h.state_dict()
del target_state_dict["PositionalEncoder.Parameter.weight"]
source_state_dict = vit.state_dict()
pos_embed = source_state_dict["pos_embed"]
del source_state_dict["pos_embed"]
target_rel_keys = [
(
f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.horizontal_embedding",
f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.vertical_embedding",
)
for i in range(1, 33)
]
source_rel_keys = [(f"blocks.{i}.attn.rel_pos_w", f"blocks.{i}.attn.rel_pos_h") for i in range(32)]
rel_items: dict[str, Tensor] = {}
for (key_w, key_h), (target_key_w, target_key_h) in zip(source_rel_keys, target_rel_keys):
rel_items[target_key_w] = source_state_dict[key_w]
rel_items[target_key_h] = source_state_dict[key_h]
del source_state_dict[key_w]
del source_state_dict[key_h]
del target_state_dict[target_key_w]
del target_state_dict[target_key_h]
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder)
embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight)
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
converted_source.update(rel_items)
refiners_sam_vit_h.load_state_dict(state_dict=converted_source)
assert converter.compare_models((x,), threshold=1e-2)
return converted_source
def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
refiners_mask_decoder = MaskDecoder()
image_embedding = torch.randn(1, 256, 64, 64)
dense_positional_embedding = torch.randn(1, 256, 64, 64)
point_embedding = torch.randn(1, 3, 256)
mask_embedding = torch.randn(1, 256, 64, 64)
from segment_anything.modeling.common import LayerNorm2d # type: ignore
import refiners.fluxion.layers as fl
assert issubclass(LayerNorm2d, nn.Module)
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
converter = ModelConverter(
source_model=mask_decoder,
target_model=refiners_mask_decoder,
custom_layer_mapping=custom_layers, # type: ignore
)
inputs = {
"image_embeddings": image_embedding,
"image_pe": dense_positional_embedding,
"sparse_prompt_embeddings": point_embedding,
"dense_prompt_embeddings": mask_embedding,
"multimask_output": True,
}
refiners_mask_decoder.set_image_embedding(image_embedding)
refiners_mask_decoder.set_point_embedding(point_embedding)
refiners_mask_decoder.set_mask_embedding(mask_embedding)
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
assert mapping is not None
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
state_dict = converter._convert_state_dict( # type: ignore
source_state_dict=mask_decoder.state_dict(),
target_state_dict=refiners_mask_decoder.state_dict(),
state_dict_mapping=mapping,
)
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
) # type: ignore
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
refiners_mask_decoder.set_image_embedding(image_embedding)
refiners_mask_decoder.set_point_embedding(point_embedding)
refiners_mask_decoder.set_mask_embedding(mask_embedding)
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
# Perform (1) upscaling then (2) mask prediction in this order (as in the official implementation) to make
# `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default)
matmul = refiners_mask_decoder.ensure_find(fl.Matmul)
def forward_swapped_order(self: Any, *args: Any) -> Any:
y = self[1](*args) # (1)
x = self[0](*args) # (2)
return torch.matmul(input=x, other=y)
matmul.forward = types.MethodType(forward_swapped_order, matmul)
assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3)
return state_dict
def main() -> None:
parser = argparse.ArgumentParser(description="Converts a Segment Anything ViT model to a Refiners SAMViTH model")
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="sam_vit_h_4b8939.pth",
# required=True,
help="Path to the Segment Anything model weights",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default="segment-anything-h.safetensors",
help="Output path for converted model (as safetensors).",
)
parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False")
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
args = parser.parse_args(namespace=Args())
sam_h = build_sam_vit_h() # type: ignore
sam_h.load_state_dict(state_dict=load_tensors(args.source_path))
vit_state_dict = convert_vit(vit=sam_h.image_encoder)
mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder)
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
output_state_dict = {
**{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()},
**{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()},
**{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()},
**{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()},
}
if args.half:
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}
save_to_safetensors(path=args.output_path, tensors=output_state_dict)
if __name__ == "__main__":
main()
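# Illustrative invocation (a sketch mirroring how prepare_test_weights.py calls this script; adjust paths as needed):
#   python scripts/conversion/convert_segment_anything.py \
#       --from tests/weights/sam_vit_h_4b8939.pth --to tests/weights/segment-anything-h.safetensors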

View file

@ -1,149 +0,0 @@
import argparse
from pathlib import Path
from typing import NamedTuple, cast
import torch
from torch import nn
from transformers import CLIPVisionModelWithProjection # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
class Args(argparse.Namespace):
source_path: str
subfolder: str
output_path: str | None
half: bool
verbose: bool
threshold: float
class CLIPImageEncoderConfig(NamedTuple):
architectures: list[str]
num_channels: int
hidden_size: int
hidden_act: str
image_size: int
projection_dim: int
patch_size: int
num_hidden_layers: int
num_attention_heads: int
intermediate_size: int
layer_norm_eps: float
def setup_converter(args: Args) -> ModelConverter:
# low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
source: nn.Module = CLIPVisionModelWithProjection.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path,
subfolder=args.subfolder,
low_cpu_mem_usage=False,
)
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
config = cast(CLIPImageEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType]
assert (
config.architectures[0] == "CLIPVisionModelWithProjection"
), f"Unsupported architecture: {config.architectures[0]}"
assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"
target = CLIPImageEncoder(
image_size=config.image_size,
embedding_dim=config.hidden_size,
output_dim=config.projection_dim,
patch_size=config.patch_size,
num_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
feedforward_dim=config.intermediate_size,
layer_norm_eps=config.layer_norm_eps,
)
x = torch.randn(1, 3, config.image_size, config.image_size)
converter = ModelConverter(source_model=source, target_model=target, verbose=True)
# Custom conversion logic since the class embedding (fl.Parameter layer) is not supported out-of-the-box by the
# converter
mapping = converter.map_state_dicts((x,))
assert mapping is not None
source_state_dict = source.state_dict()
target_state_dict = target.state_dict()
# Remove the class embedding from state dict since it was not mapped by the model converter
class_embedding = target.ensure_find(fl.Parameter)
class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
assert class_embedding_key is not None
assert class_embedding_key in target_state_dict
del target_state_dict[class_embedding_key]
converted_state_dict = converter._convert_state_dict( # type: ignore[reportPrivateUsage]
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
target.load_state_dict(state_dict=converted_state_dict, strict=False)
# Ad hoc post-conversion steps
embed = source.vision_model.embeddings.class_embedding
class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight)) # type: ignore
assert converter.compare_models((x,), threshold=args.threshold)
return converter
def main() -> None:
parser = argparse.ArgumentParser(
description="Converts a CLIPImageEncoder from the library transformers from the HuggingFace Hub to refiners."
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="stabilityai/stable-diffusion-2-1-unclip",
help=(
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
" stabilityai/stable-diffusion-2-1-unclip"
),
)
parser.add_argument(
"--subfolder",
type=str,
dest="subfolder",
default="image_encoder",
help="Subfolder in the source path where the model is located inside the Hub. Default: image_encoder",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
parser.add_argument("--half", action="store_true", help="Convert to half precision.")
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
parser.add_argument("--threshold", type=float, default=1e-2, help="Threshold for model comparison. Default: 1e-2")
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
converter = setup_converter(args=args)
# Do not use converter.save_to_safetensors since it is not in a valid state due to the ad hoc conversion
state_dict = converter.target_model.state_dict()
if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
main()
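# Illustrative invocation (a sketch mirroring how prepare_test_weights.py calls this script):
#   python scripts/conversion/convert_transformers_clip_image_model.py \
#       --from tests/weights/stabilityai/stable-diffusion-2-1-unclip --to tests/weights/CLIPImageEncoderH.safetensors --half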

View file

@ -1,150 +0,0 @@
import argparse
from pathlib import Path
from typing import NamedTuple, cast
from torch import nn
from transformers import CLIPTextModel, CLIPTextModelWithProjection # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderG, CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
class Args(argparse.Namespace):
source_path: str
subfolder: str
output_path: str | None
half: bool
verbose: bool
class CLIPTextEncoderConfig(NamedTuple):
architectures: list[str]
vocab_size: int
hidden_size: int
intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
hidden_act: str
layer_norm_eps: float
projection_dim: int
def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
# low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
source: nn.Module = cls.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path,
subfolder=args.subfolder,
low_cpu_mem_usage=False,
)
assert isinstance(source, nn.Module), "Source model is not an nn.Module"
config = cast(CLIPTextEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType]
architecture: str = config.architectures[0]
embedding_dim: int = config.hidden_size
projection_dim: int = config.projection_dim
use_quick_gelu = config.hidden_act == "quick_gelu"
assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
target = CLIPTextEncoder(
embedding_dim=config.hidden_size,
num_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
feedforward_dim=config.intermediate_size,
use_quick_gelu=use_quick_gelu,
)
if architecture == "CLIPTextModelWithProjection":
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
text = "What a nice cat you have there!"
tokenizer = target.ensure_find(CLIPTokenizer)
tokens = tokenizer(text)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(tokens,), target_args=(text,)):
raise RuntimeError("Model conversion failed")
return converter
def main() -> None:
parser = argparse.ArgumentParser(
description="Converts a CLIPTextEncoder from the library transformers from the HuggingFace Hub to refiners."
)
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="runwayml/stable-diffusion-v1-5",
help=(
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
" runwayml/stable-diffusion-v1-5"
),
)
parser.add_argument(
"--subfolder",
type=str,
dest="subfolder",
default="text_encoder",
help=(
"Subfolder in the source path where the model is located inside the Hub. Default: text_encoder (for"
" CLIPTextModel)"
),
)
parser.add_argument(
"--subfolder2",
type=str,
dest="subfolder2",
default=None,
help="Additional subfolder for the 2nd text encoder (useful for SDXL). Default: None",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
" source path."
),
)
parser.add_argument("--half", action="store_true", help="Convert to half precision.")
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
converter = setup_converter(args=args)
if args.subfolder2 is not None:
# Assume this is the second text encoder of Stable Diffusion XL
args.subfolder = args.subfolder2
converter2 = setup_converter(args=args, with_projection=True)
text_encoder_l = CLIPTextEncoderL()
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
projection = cast(CLIPTextEncoder, converter2.target_model)[-1]
assert isinstance(projection, fl.Linear)
text_encoder_g_with_projection = CLIPTextEncoderG()
text_encoder_g_with_projection.append(module=projection)
text_encoder_g_with_projection.load_state_dict(state_dict=converter2.get_state_dict())
projection = text_encoder_g_with_projection.pop(index=-1)
assert isinstance(projection, fl.Linear)
double_text_encoder = DoubleTextEncoder(
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=projection
)
state_dict = double_text_encoder.state_dict()
if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}
save_to_safetensors(path=args.output_path, tensors=state_dict)
else:
converter.save_to_safetensors(path=args.output_path, half=args.half)
if __name__ == "__main__":
main()
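# Illustrative invocations (sketches mirroring how prepare_test_weights.py calls this script):
#   python scripts/conversion/convert_transformers_clip_text_model.py \
#       --from tests/weights/runwayml/stable-diffusion-v1-5 --to tests/weights/CLIPTextEncoderL.safetensors --half
#   python scripts/conversion/convert_transformers_clip_text_model.py \
#       --from tests/weights/stabilityai/stable-diffusion-xl-base-1.0 --to tests/weights/DoubleCLIPTextEncoder.safetensors \
#       --subfolder2 text_encoder_2 --half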

View file

@ -1,945 +0,0 @@
"""
Download and convert weights for testing
To see what weights will be downloaded and converted, run:
DRY_RUN=1 python scripts/prepare_test_weights.py
"""
import hashlib
import os
import subprocess
import sys
from urllib.parse import urlparse
import gdown
import requests
from tqdm import tqdm
# Set the base directory to the parent directory of the script
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
test_weights_dir = os.path.join(project_dir, "tests", "weights")
previous_line = "\033[F"
download_count = 0
bytes_count = 0
def die(message: str) -> None:
print(message, file=sys.stderr)
sys.exit(1)
def rel(path: str) -> str:
return os.path.relpath(path, project_dir)
def calc_hash(filepath: str) -> str:
with open(filepath, "rb") as f:
data = f.read()
found = hashlib.blake2b(data, digest_size=int(32 / 8)).hexdigest()
return found
def check_hash(path: str, expected: str) -> str:
found = calc_hash(path)
if found != expected:
die(f"❌ Invalid hash for {path} ({found} != {expected})")
return found
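# Note: blake2b with digest_size=4 yields an 8-character hex digest, which is why the `expected_hash`
# values used throughout this script are short strings such as "89992ea6".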
def download_file(
url: str,
dest_folder: str,
dry_run: bool | None = None,
skip_existing: bool = True,
expected_hash: str | None = None,
filename: str | None = None,
):
"""
Downloads a file
Features:
- shows a progress bar
- skips existing files
- uses a temporary file to prevent partial downloads
- can do a dry run to check the url is valid
- displays the downloaded file hash
"""
global download_count, bytes_count
filename = os.path.basename(urlparse(url).path) if filename is None else filename
dest_filename = os.path.join(dest_folder, filename)
temp_filename = dest_filename + ".part"
dry_run = bool(os.environ.get("DRY_RUN") == "1") if dry_run is None else dry_run
is_downloaded = os.path.exists(dest_filename)
if is_downloaded and skip_existing:
skip_icon = "✖️ "
else:
skip_icon = "🔽"
if dry_run:
response = requests.head(url, allow_redirects=True)
readable_size = ""
if response.status_code == 200:
content_length = response.headers.get("content-length")
if content_length:
size_in_bytes = int(content_length)
readable_size = human_readable_size(size_in_bytes)
download_count += 1
bytes_count += size_in_bytes
print(f"{skip_icon} {response.status_code} READY {readable_size:<8} {url}")
else:
print(f"{skip_icon} {response.status_code} ERROR {readable_size:<8} {url}")
return
if skip_existing and is_downloaded:
print(f"{skip_icon} Skipping previously downloaded {url}")
if expected_hash is not None:
check_hash(dest_filename, expected_hash)
return
os.makedirs(dest_folder, exist_ok=True)
print(f"🔽 Downloading {url} => '{rel(dest_filename)}'", end="\n")
response = requests.get(url, stream=True)
if response.status_code != 200:
print(response.content[:1000])
die(f"Failed to download {url}. Status code: {response.status_code}")
total = int(response.headers.get("content-length", 0))
bar = tqdm(
desc=filename,
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
leave=False,
)
with open(temp_filename, "wb") as f, bar:
for data in response.iter_content(chunk_size=1024 * 1000):
size = f.write(data)
bar.update(size)
os.rename(temp_filename, dest_filename)
calculated_hash = calc_hash(dest_filename)
print(f"{previous_line}✅ Downloaded {calculated_hash} {url} => '{rel(dest_filename)}' ")
if expected_hash is not None:
check_hash(dest_filename, expected_hash)
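# Minimal usage sketch (the URL, folder and hash below are hypothetical, for illustration only):
#   download_file(
#       "https://example.com/weights/model.safetensors",
#       dest_folder=os.path.join(test_weights_dir, "example"),
#       expected_hash="00000000",
#   )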
def download_files(urls: list[str], dest_folder: str):
for url in urls:
download_file(url, dest_folder)
def human_readable_size(size: int | float, decimal_places: int = 2) -> str:
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
if size < 1024.0:
break
size /= 1024.0
return f"{size:.{decimal_places}f}{unit}" # type: ignore
def download_sd_text_encoder(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "text_encoder"):
encoder_filename = "model.safetensors" if "inpainting" not in hf_repo_id else "model.fp16.safetensors"
base_url = f"https://huggingface.co/{hf_repo_id}"
download_files(
urls=[
f"{base_url}/raw/main/{subdir}/config.json",
f"{base_url}/resolve/main/{subdir}/{encoder_filename}",
],
dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir),
)
def download_sd_tokenizer(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "tokenizer"):
download_files(
urls=[
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/merges.txt",
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/special_tokens_map.json",
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/tokenizer_config.json",
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/vocab.json",
],
dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir),
)
def download_sd_base(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"):
is_inpainting = "inpainting" in hf_repo_id
ext = "safetensors" if not is_inpainting else "bin"
base_folder = os.path.join(test_weights_dir, hf_repo_id)
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/model_index.json", base_folder)
download_file(
f"https://huggingface.co/{hf_repo_id}/raw/main/scheduler/scheduler_config.json",
os.path.join(base_folder, "scheduler"),
)
for subdir in ["unet", "vae"]:
subdir_folder = os.path.join(base_folder, subdir)
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder)
download_file(
f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/diffusion_pytorch_model.{ext}", subdir_folder
)
# we only need the unet for the inpainting model
if not is_inpainting:
download_sd_text_encoder(hf_repo_id, "text_encoder")
download_sd_tokenizer(hf_repo_id, "tokenizer")
def download_sd15(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"):
download_sd_base(hf_repo_id)
base_folder = os.path.join(test_weights_dir, hf_repo_id)
subdir = "feature_extractor"
download_file(
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/preprocessor_config.json",
os.path.join(base_folder, subdir),
)
if "inpainting" not in hf_repo_id:
subdir = "safety_checker"
subdir_folder = os.path.join(base_folder, subdir)
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder)
download_file(f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/model.safetensors", subdir_folder)
def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
download_sd_base(hf_repo_id)
download_sd_text_encoder(hf_repo_id, "text_encoder_2")
download_sd_tokenizer(hf_repo_id, "tokenizer_2")
def download_vae_fp16_fix():
download_files(
urls=[
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/raw/main/config.json",
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/diffusion_pytorch_model.safetensors",
],
dest_folder=os.path.join(test_weights_dir, "madebyollin", "sdxl-vae-fp16-fix"),
)
def download_vae_ft_mse():
download_files(
urls=[
"https://huggingface.co/stabilityai/sd-vae-ft-mse/raw/main/config.json",
"https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.safetensors",
],
dest_folder=os.path.join(test_weights_dir, "stabilityai", "sd-vae-ft-mse"),
)
def download_loras():
dest_folder = os.path.join(test_weights_dir, "loras", "pokemon-lora")
download_file(
"https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin",
dest_folder,
expected_hash="89992ea6",
)
dest_folder = os.path.join(test_weights_dir, "loras", "dpo-lora")
download_file(
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors",
dest_folder,
expected_hash="a51e9144",
)
dest_folder = os.path.join(test_weights_dir, "loras", "sliders")
download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder, expected_hash="908f07d3")
download_file(
"https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder, expected_hash="25652004"
)
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder, expected_hash="ee170e4d")
dest_folder = os.path.join(test_weights_dir, "loras")
download_file(
"https://civitai.com/api/download/models/140624",
filename="Sci-fi_Environments_sdxl.safetensors",
dest_folder=dest_folder,
expected_hash="6a4afda8",
)
download_file(
"https://civitai.com/api/download/models/135931",
filename="pixel-art-xl-v1.1.safetensors",
dest_folder=dest_folder,
expected_hash="71aaa6ca",
)
def download_preprocessors():
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
download_file("https://huggingface.co/spaces/carolineec/informativedrawings/resolve/main/model2.pth", dest_folder)
def download_controlnet():
base_folder = os.path.join(test_weights_dir, "lllyasviel")
controlnets = [
"control_v11p_sd15_canny",
"control_v11f1p_sd15_depth",
"control_v11p_sd15_normalbae",
"control_v11p_sd15_lineart",
]
for net in controlnets:
net_folder = os.path.join(base_folder, net)
urls = [
f"https://huggingface.co/lllyasviel/{net}/raw/main/config.json",
f"https://huggingface.co/lllyasviel/{net}/resolve/main/diffusion_pytorch_model.safetensors",
]
download_files(urls, net_folder)
tile_folder = os.path.join(base_folder, "control_v11f1e_sd15_tile")
urls = [
"https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/raw/main/config.json",
"https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/main/diffusion_pytorch_model.bin",
]
download_files(urls, tile_folder)
mfidabel_folder = os.path.join(test_weights_dir, "mfidabel", "controlnet-segment-anything")
urls = [
"https://huggingface.co/mfidabel/controlnet-segment-anything/raw/main/config.json",
"https://huggingface.co/mfidabel/controlnet-segment-anything/resolve/main/diffusion_pytorch_model.bin",
]
download_files(urls, mfidabel_folder)
def download_control_lora_fooocus():
base_folder = os.path.join(test_weights_dir, "lllyasviel", "misc")
download_file(
url=f"https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors",
dest_folder=base_folder,
expected_hash="fec9e32b",
)
download_file(
url=f"https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors",
dest_folder=base_folder,
expected_hash="fc04b120",
)
def download_unclip():
base_folder = os.path.join(test_weights_dir, "stabilityai", "stable-diffusion-2-1-unclip")
download_file(
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/model_index.json", base_folder
)
image_encoder_folder = os.path.join(base_folder, "image_encoder")
urls = [
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/image_encoder/config.json",
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.safetensors",
]
download_files(urls, image_encoder_folder)
def download_ip_adapter():
base_folder = os.path.join(test_weights_dir, "h94", "IP-Adapter")
models_folder = os.path.join(base_folder, "models")
urls = [
"https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.bin",
"https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.bin",
]
download_files(urls, models_folder)
sdxl_models_folder = os.path.join(base_folder, "sdxl_models")
urls = [
"https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.bin",
"https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
]
download_files(urls, sdxl_models_folder)
def download_t5xl_fp16():
base_folder = os.path.join(test_weights_dir, "QQGYLab", "T5XLFP16")
urls = [
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/config.json",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/model.safetensors",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/special_tokens_map.json",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/spiece.model",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer.json",
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer_config.json",
]
download_files(urls, base_folder)
def download_ella_adapter():
download_t5xl_fp16()
base_folder = os.path.join(test_weights_dir, "QQGYLab", "ELLA")
download_file(
"https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors",
base_folder,
expected_hash="5af7b200",
)
def download_t2i_adapter():
base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
urls = [
"https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/raw/main/config.json",
"https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/resolve/main/diffusion_pytorch_model.bin",
]
download_files(urls, base_folder)
canny_sdxl_folder = os.path.join(test_weights_dir, "TencentARC", "t2i-adapter-canny-sdxl-1.0")
urls = [
"https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/raw/main/config.json",
"https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors",
]
download_files(urls, canny_sdxl_folder)
def download_sam():
weights_folder = os.path.join(test_weights_dir)
download_file(
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", weights_folder, expected_hash="06785e66"
)
def download_hq_sam():
weights_folder = os.path.join(test_weights_dir)
download_file(
"https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", weights_folder, expected_hash="66da2472"
)
def download_dinov2():
# For conversion
weights_folder = os.path.join(test_weights_dir)
urls = [
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
]
download_files(urls, weights_folder)
def download_lcm_base():
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
download_file(f"https://huggingface.co/latent-consistency/lcm-sdxl/raw/main/config.json", base_folder)
download_file(
f"https://huggingface.co/latent-consistency/lcm-sdxl/resolve/main/diffusion_pytorch_model.safetensors",
base_folder,
)
def download_lcm_lora():
download_file(
"https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors",
dest_folder=test_weights_dir,
filename="sdxl-lcm-lora.safetensors",
expected_hash="6312a30a",
)
def download_sdxl_lightning_base():
base_folder = os.path.join(test_weights_dir, "ByteDance/SDXL-Lightning")
download_file(
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_unet.safetensors",
base_folder,
expected_hash="1b76cca3",
)
download_file(
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_1step_unet_x0.safetensors",
base_folder,
expected_hash="38e605bd",
)
def download_sdxl_lightning_lora():
download_file(
"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors",
dest_folder=test_weights_dir,
expected_hash="9783edac",
)
def download_ic_light():
download_file(
"https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors",
dest_folder=test_weights_dir,
expected_hash="bce70123",
)
def download_mvanet():
fn = "Model_80.pth"
dest_folder = os.path.join(test_weights_dir, "mvanet")
dest_filename = os.path.join(dest_folder, fn)
if os.environ.get("DRY_RUN") == "1":
return
if os.path.exists(dest_filename):
print(f"✖️ Skipping previously downloaded mvanet/{fn}")
else:
os.makedirs(dest_folder, exist_ok=True)
print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n")
gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True)
print(f"{previous_line}✅ Downloaded mvanet/{fn} => '{rel(dest_filename)}' ")
check_hash(dest_filename, "b915d492")
def download_box_segmenter():
download_file(
"https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors",
dest_folder=test_weights_dir,
filename="finegrain-box-segmenter-v0-1.safetensors",
expected_hash="e0450e8c",
)
def printg(msg: str):
"""print in green color"""
print("\033[92m" + msg + "\033[0m")
def run_conversion_script(
script_filename: str,
from_weights: str,
to_weights: str,
half: bool = False,
expected_hash: str | None = None,
additional_args: list[str] | None = None,
skip_existing: bool = True,
):
if skip_existing and expected_hash and os.path.exists(to_weights):
found_hash = check_hash(to_weights, expected_hash)
if expected_hash == found_hash:
printg(f"✖️ Skipping converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ")
return
msg = f"Converting {from_weights} to {to_weights}"
printg(msg)
args = ["python", f"scripts/conversion/{script_filename}", "--from", from_weights, "--to", to_weights]
if half:
args.append("--half")
if additional_args:
args.extend(additional_args)
subprocess.run(args, check=True)
if expected_hash is not None:
found_hash = check_hash(to_weights, expected_hash)
printg(f"✅ Converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ")
else:
printg(f"✅⚠️ Converted from {from_weights} to {to_weights} (no hash check performed)")
def convert_sd15():
run_conversion_script(
script_filename="convert_transformers_clip_text_model.py",
from_weights="tests/weights/runwayml/stable-diffusion-v1-5",
to_weights="tests/weights/CLIPTextEncoderL.safetensors",
half=True,
expected_hash="6c9cbc59",
)
run_conversion_script(
"convert_diffusers_autoencoder_kl.py",
"tests/weights/runwayml/stable-diffusion-v1-5",
"tests/weights/lda.safetensors",
expected_hash="329e369c",
)
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/runwayml/stable-diffusion-v1-5",
"tests/weights/unet.safetensors",
half=True,
expected_hash="f81ac65a",
)
os.makedirs("tests/weights/inpainting", exist_ok=True)
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/runwayml/stable-diffusion-inpainting",
"tests/weights/inpainting/unet.safetensors",
half=True,
expected_hash="c07a8c61",
)
def convert_sdxl():
run_conversion_script(
"convert_transformers_clip_text_model.py",
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
"tests/weights/DoubleCLIPTextEncoder.safetensors",
half=True,
expected_hash="7f99c30b",
additional_args=["--subfolder2", "text_encoder_2"],
)
run_conversion_script(
"convert_diffusers_autoencoder_kl.py",
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
"tests/weights/sdxl-lda.safetensors",
half=True,
expected_hash="7464e9dc",
)
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
"tests/weights/sdxl-unet.safetensors",
half=True,
expected_hash="2e5c4911",
)
def convert_vae_ft_mse():
run_conversion_script(
"convert_diffusers_autoencoder_kl.py",
"tests/weights/stabilityai/sd-vae-ft-mse",
"tests/weights/lda_ft_mse.safetensors",
half=True,
expected_hash="4d0bae7e",
)
def convert_vae_fp16_fix():
run_conversion_script(
"convert_diffusers_autoencoder_kl.py",
"tests/weights/madebyollin/sdxl-vae-fp16-fix",
"tests/weights/sdxl-lda-fp16-fix.safetensors",
additional_args=["--subfolder", "''"],
half=True,
expected_hash="98c7e998",
)
def convert_preprocessors():
subprocess.run(
[
"curl",
"-L",
"https://raw.githubusercontent.com/carolineec/informative-drawings/main/model.py",
"-o",
"src/model.py",
],
check=True,
)
run_conversion_script(
"convert_informative_drawings.py",
"tests/weights/carolineec/informativedrawings/model2.pth",
"tests/weights/informative-drawings.safetensors",
expected_hash="93dca207",
)
os.remove("src/model.py")
def convert_controlnet():
os.makedirs("tests/weights/controlnet", exist_ok=True)
run_conversion_script(
"convert_diffusers_controlnet.py",
"tests/weights/lllyasviel/control_v11p_sd15_canny",
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors",
expected_hash="9a1a48cf",
)
run_conversion_script(
"convert_diffusers_controlnet.py",
"tests/weights/lllyasviel/control_v11f1p_sd15_depth",
"tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors",
expected_hash="bbe7e5a6",
)
run_conversion_script(
"convert_diffusers_controlnet.py",
"tests/weights/lllyasviel/control_v11p_sd15_normalbae",
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors",
expected_hash="9fa88ed5",
)
run_conversion_script(
"convert_diffusers_controlnet.py",
"tests/weights/lllyasviel/control_v11p_sd15_lineart",
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors",
expected_hash="c29e8c03",
)
run_conversion_script(
"convert_diffusers_controlnet.py",
"tests/weights/mfidabel/controlnet-segment-anything",
"tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors",
expected_hash="d536eebb",
)
run_conversion_script(
"convert_diffusers_controlnet.py",
"tests/weights/lllyasviel/control_v11f1e_sd15_tile",
"tests/weights/controlnet/lllyasviel_control_v11f1e_sd15_tile.safetensors",
expected_hash="42463af8",
)
def convert_unclip():
run_conversion_script(
"convert_transformers_clip_image_model.py",
"tests/weights/stabilityai/stable-diffusion-2-1-unclip",
"tests/weights/CLIPImageEncoderH.safetensors",
half=True,
expected_hash="4ddb44d2",
)
def convert_ip_adapter():
run_conversion_script(
"convert_diffusers_ip_adapter.py",
"tests/weights/h94/IP-Adapter/models/ip-adapter_sd15.bin",
"tests/weights/ip-adapter_sd15.safetensors",
expected_hash="3fb0472e",
)
run_conversion_script(
"convert_diffusers_ip_adapter.py",
"tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin",
"tests/weights/ip-adapter_sdxl_vit-h.safetensors",
half=True,
expected_hash="860518fe",
)
run_conversion_script(
"convert_diffusers_ip_adapter.py",
"tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin",
"tests/weights/ip-adapter-plus_sd15.safetensors",
half=True,
expected_hash="aba8503b",
)
run_conversion_script(
"convert_diffusers_ip_adapter.py",
"tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
"tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors",
half=True,
expected_hash="545d5ce7",
)
def convert_ella_adapter():
os.makedirs("tests/weights/ELLA-Adapter", exist_ok=True)
run_conversion_script(
"convert_ella_adapter.py",
"tests/weights/QQGYLab/ELLA/ella-sd1.5-tsc-t5xl.safetensors",
"tests/weights/ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors",
half=True,
expected_hash="b8244cb6",
)
def convert_t2i_adapter():
os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
run_conversion_script(
"convert_diffusers_t2i_adapter.py",
"tests/weights/TencentARC/t2iadapter_depth_sd15v2",
"tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors",
half=True,
expected_hash="bb2b3115",
)
run_conversion_script(
"convert_diffusers_t2i_adapter.py",
"tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0",
"tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors",
half=True,
expected_hash="f07249a6",
)
def convert_sam():
run_conversion_script(
"convert_segment_anything.py",
"tests/weights/sam_vit_h_4b8939.pth",
"tests/weights/segment-anything-h.safetensors",
expected_hash="5ffb976f",
)
def convert_hq_sam():
run_conversion_script(
"convert_hq_segment_anything.py",
"tests/weights/sam_hq_vit_h.pth",
"tests/weights/refiners-sam-hq-vit-h.safetensors",
expected_hash="b2f5e79f",
)
def convert_dinov2():
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vits14_pretrain.pth",
"tests/weights/dinov2_vits14_pretrain.safetensors",
expected_hash="af000ded",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitb14_pretrain.pth",
"tests/weights/dinov2_vitb14_pretrain.safetensors",
expected_hash="d6294087",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitl14_pretrain.pth",
"tests/weights/dinov2_vitl14_pretrain.safetensors",
expected_hash="ddd4819f",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitg14_pretrain.pth",
"tests/weights/dinov2_vitg14_pretrain.safetensors",
expected_hash="880c61f5",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
expected_hash="080247c7",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitb14_reg4_pretrain.pth",
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
expected_hash="5cd4d408",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitl14_reg4_pretrain.pth",
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
expected_hash="b1221702",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitg14_reg4_pretrain.pth",
"tests/weights/dinov2_vitg14_reg4_pretrain.safetensors",
expected_hash="639398eb",
)
def convert_control_lora_fooocus():
run_conversion_script(
"convert_fooocus_control_lora.py",
"tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors",
"tests/weights/control-loras/refiners_control-lora-canny-rank128.safetensors",
expected_hash="4d505134",
)
run_conversion_script(
"convert_fooocus_control_lora.py",
"tests/weights/lllyasviel/misc/fooocus_xl_cpds_128.safetensors",
"tests/weights/control-loras/refiners_fooocus_xl_cpds_128.safetensors",
expected_hash="d81aa461",
)
def convert_lcm_base():
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/latent-consistency/lcm-sdxl",
"tests/weights/sdxl-lcm-unet.safetensors",
half=True,
expected_hash="e161b20c",
)
def convert_sdxl_lightning_base():
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
"tests/weights/sdxl_lightning_4step_unet.safetensors",
additional_args=[
"--override-weights",
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors",
],
half=True,
expected_hash="cfdc46da",
)
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
"tests/weights/sdxl_lightning_1step_unet_x0.safetensors",
additional_args=[
"--override-weights",
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_1step_unet_x0.safetensors",
],
half=True,
expected_hash="21166a64",
)
def convert_ic_light():
run_conversion_script(
"convert_ic_light.py",
"tests/weights/iclight_sd15_fc.safetensors",
"tests/weights/iclight_sd15_fc-refiners.safetensors",
half=False,
expected_hash="be315c1f",
)
def convert_mvanet():
run_conversion_script(
"convert_mvanet.py",
"tests/weights/mvanet/Model_80.pth",
"tests/weights/mvanet/mvanet.safetensors",
half=True,
expected_hash="bf9ae4cb",
)
def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5")
download_sd15("runwayml/stable-diffusion-inpainting")
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
download_vae_ft_mse()
download_vae_fp16_fix()
download_loras()
download_preprocessors()
download_controlnet()
download_unclip()
download_ip_adapter()
download_t2i_adapter()
download_ella_adapter()
download_sam()
download_hq_sam()
download_dinov2()
download_control_lora_fooocus()
download_lcm_base()
download_lcm_lora()
download_sdxl_lightning_base()
download_sdxl_lightning_lora()
download_ic_light()
download_mvanet()
download_box_segmenter()
def convert_all():
convert_sd15()
convert_sdxl()
convert_vae_ft_mse()
convert_vae_fp16_fix()
# Note: LoRAs are not converted here; this is done at runtime by `SDLoraManager`
convert_preprocessors()
convert_controlnet()
convert_unclip()
convert_ip_adapter()
convert_t2i_adapter()
convert_ella_adapter()
convert_sam()
convert_hq_sam()
convert_dinov2()
convert_control_lora_fooocus()
convert_lcm_base()
convert_sdxl_lightning_base()
convert_ic_light()
convert_mvanet()
def main():
try:
download_all()
print(f"{download_count} files ({human_readable_size(bytes_count)})\n")
if not bool(os.environ.get("DRY_RUN") == "1"):
printg("Converting weights to refiners format\n")
convert_all()
except KeyboardInterrupt:
print("Stopped")
if __name__ == "__main__":
main()

View file

@ -1,654 +0,0 @@
from collections import defaultdict
from enum import Enum, auto
from pathlib import Path
from typing import Any, DefaultDict, TypedDict
import torch
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle
from refiners.fluxion.utils import no_grad, norm, save_to_safetensors
TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Linear,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm,
nn.GroupNorm,
nn.Embedding,
nn.MaxPool2d,
nn.AvgPool2d,
nn.AdaptiveAvgPool2d,
]
ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]]
class ModuleArgsDict(TypedDict):
"""Represents positional and keyword arguments passed to a module.
- `positional`: A tuple of positional arguments.
- `keyword`: A dictionary of keyword arguments.
"""
positional: tuple[Any, ...]
keyword: dict[str, Any]
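# `ModelConverter.ModuleArgs` (defined below) accepts any of these three forms; a sketch, where the
# keyword name is illustrative only:
#   converter.run(source_args=(x,))                                  # tuple of positional arguments
#   converter.run(source_args={"pixel_values": x})                   # dictionary of keyword arguments
#   converter.run(source_args={"positional": (x,), "keyword": {}})   # ModuleArgsDict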
class ConversionStage(Enum):
"""Represents the current stage of the conversion process.
Attributes:
INIT: The conversion process has not started.
BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers.
SHAPE_AND_LAYERS_MATCH: The shapes of the source and target models' layers match.
MODELS_OUTPUT_AGREE: The outputs of the source and target models agree.
"""
INIT = auto()
BASIC_LAYERS_MATCH = auto()
SHAPE_AND_LAYERS_MATCH = auto()
MODELS_OUTPUT_AGREE = auto()
class ModelConverter:
"""Converts a model's state_dict to match another model's state_dict.
Note: The conversion process consists of the following steps
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
Example:
```py
source = ...
target = ...
converter = ModelConverter(
source_model=source,
target_model=target,
threshold=0.1,
verbose=False
)
is_converted = converter(args)
if is_converted:
converter.save_to_safetensors(path="converted_model.pt")
```
"""
ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
stage: ConversionStage = ConversionStage.INIT
_stored_mapping: dict[str, str] | None = None
def __init__(
self,
source_model: nn.Module,
target_model: nn.Module,
source_keys_to_skip: list[str] | None = None,
target_keys_to_skip: list[str] | None = None,
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
threshold: float = 1e-5,
skip_output_check: bool = False,
skip_init_check: bool = False,
verbose: bool = True,
) -> None:
"""Initializes the ModelConverter.
Args:
source_model: The model to convert from.
target_model: The model to convert to.
source_keys_to_skip: A list of keys to skip when tracing the source model.
target_keys_to_skip: A list of keys to skip when tracing the target model.
custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.
threshold: The threshold for comparing outputs between the source and target models.
skip_output_check: Whether to skip comparing the outputs of the source and target models.
skip_init_check: Whether to skip checking that the source and target models have the same number of basic
layers.
verbose: Whether to print messages during the conversion process.
"""
self.source_model = source_model
self.target_model = target_model
self.source_keys_to_skip = source_keys_to_skip or []
self.target_keys_to_skip = target_keys_to_skip or []
self.custom_layer_mapping = custom_layer_mapping or {}
self.threshold = threshold
self.skip_output_check = skip_output_check
self.skip_init_check = skip_init_check
self.verbose = verbose
def __repr__(self) -> str:
return (
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
)
def __bool__(self) -> bool:
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
"""Run the conversion process.
Args:
source_args: The arguments to pass to the source model; either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
target_args: The arguments to pass to the target model; either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
Returns:
True if the conversion process is done and the models agree.
"""
if target_args is None:
target_args = source_args
match self.stage:
case ConversionStage.MODELS_OUTPUT_AGREE:
self._increment_stage()
return True
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
source_args=source_args, target_args=target_args
):
self._increment_stage()
return True
case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
source_args=source_args, target_args=target_args
):
self._increment_stage()
return self.run(source_args=source_args, target_args=target_args)
case ConversionStage.INIT if self._run_init_stage():
self._increment_stage()
return self.run(source_args=source_args, target_args=target_args)
case _:
self._log(message=f"Conversion failed at stage {self.stage.value}")
return False
def _increment_stage(self) -> None:
"""Increment the stage of the conversion process."""
match self.stage:
case ConversionStage.INIT:
self.stage = ConversionStage.BASIC_LAYERS_MATCH
self._log(
message=(
"Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
" layers..."
)
)
case ConversionStage.BASIC_LAYERS_MATCH:
self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
self._log(
message=(
"Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing"
" models..."
)
)
case ConversionStage.SHAPE_AND_LAYERS_MATCH:
if self.skip_output_check:
self._log(
message=(
"Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
" `skip_output_check` to `False`"
)
)
else:
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
self._log(
message=(
"Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
" converted model using `save_to_safetensors`"
)
)
case ConversionStage.MODELS_OUTPUT_AGREE:
self._log(
message=(
"Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
" the converted model using `save_to_safetensors`"
)
)
def get_state_dict(self) -> dict[str, Tensor]:
"""Get the converted state_dict."""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
return self.target_model.state_dict()
def get_mapping(self) -> dict[str, str]:
"""Get the mapping between the source and target models' state_dicts."""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
assert self._stored_mapping is not None, "Mapping is not stored"
return self._stored_mapping
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
"""Save the converted model to a SafeTensors file.
Warning:
This method can only be called after the conversion process is done.
Args:
path: The path to save the converted model to.
metadata: Metadata to save with the converted model.
half: Whether to save the converted model as half precision.
Raises:
ValueError: If the conversion process is not done yet. Run `converter` first.
"""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
state_dict = self.get_state_dict()
if half:
state_dict = {key: value.half() for key, value in state_dict.items()}
save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)
def map_state_dicts(
self,
source_args: ModuleArgs,
target_args: ModuleArgs | None = None,
) -> dict[str, str] | None:
"""Find a mapping between the source and target models' state_dicts.
Args:
source_args: The arguments to pass to the source model; either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
target_args: The arguments to pass to the target model; either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
Returns:
A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
"""
if target_args is None:
target_args = source_args
source_order = self._trace_module_execution_order(
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
)
target_order = self._trace_module_execution_order(
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
)
if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
return None
mapping: dict[str, str] = {}
for source_type_shape in source_order:
source_keys = source_order[source_type_shape]
target_type_shape = source_type_shape
if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
if source_custom_type == source_type_shape[0]:
target_type_shape = (target_custom_type, source_type_shape[1])
break
target_keys = target_order[target_type_shape]
mapping.update(zip(target_keys, source_keys))
return mapping
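# The returned mapping is keyed by target state_dict names and valued by source names; for instance, the
# Segment Anything script above adds {"PositionalEncoder.Parameter.weight": "pos_embed"} before calling
# `_convert_state_dict` (illustrative; the exact keys depend on the models being converted).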
def compare_models(
self,
source_args: ModuleArgs,
target_args: ModuleArgs | None = None,
threshold: float = 1e-5,
) -> bool:
"""Compare the outputs of the source and target models.
Args:
source_args: The arguments to pass to the source model; either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
target_args: The arguments to pass to the target model; either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
threshold: The threshold for comparing outputs between the source and target models.
Returns:
True if the outputs of the source and target models agree.
"""
if target_args is None:
target_args = source_args
source_outputs = self._collect_layers_outputs(
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
)
target_outputs = self._collect_layers_outputs(
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
)
diff, prev_source_key, prev_target_key = None, None, None
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
if diff > threshold:
self._log(
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
f" {target_key}, difference in norm: {diff}"
)
return False
prev_source_key, prev_target_key = source_key, target_key
self._log(message=f"Models agree. Difference in norm: {diff}")
return True
def _run_init_stage(self) -> bool:
"""Run the init stage of the conversion process."""
if self.skip_init_check:
self._log(
message=(
"Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to"
" `False`"
)
)
return True
is_count_correct = self._verify_basic_layers_count()
is_not_missing_layers = self._verify_missing_basic_layers()
return is_count_correct and is_not_missing_layers
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the basic layers match stage of the conversion process."""
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
self._stored_mapping = mapping
if mapping is None:
self._log(message="Models do not have matching shapes.")
return False
source_state_dict = self.source_model.state_dict()
target_state_dict = self.target_model.state_dict()
converted_state_dict = self._convert_state_dict(
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
self.target_model.load_state_dict(state_dict=converted_state_dict)
return True
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the shape and layers match stage of the conversion process."""
if self.skip_output_check:
self._log(
message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`"
)
return True
try:
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
return True
else:
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
return False
except Exception as e:
self._log(message=f"An error occurred while comparing the models: {e}")
return False
def _log(self, message: str) -> None:
"""Print a message if `verbose` is `True`."""
if self.verbose:
print(message)
def _debug_print_shapes(
self,
shape: ModelTypeShape,
source_keys: list[str],
target_keys: list[str],
) -> None:
"""Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
self._log(message=f"{shape}")
max_len = max(len(source_keys), len(target_keys))
for i in range(max_len):
source_key = source_keys[i] if i < len(source_keys) else "---"
target_key = target_keys[i] if i < len(target_keys) else "---"
self._log(f"\t{source_key}\t{target_key}")
@staticmethod
def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""Unpack the positional and keyword arguments passed to a module."""
match module_args:
case tuple(positional_args):
keyword_args: dict[str, Any] = {}
case {"positional": positional_args, "keyword": keyword_args}:
pass
case _:
positional_args = ()
keyword_args = dict(**module_args)
return positional_args, keyword_args
def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
"""Check if a module type is a subclass of a torch basic layer."""
return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
"""Infer the type of a basic layer."""
layer_types = (
set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
)
for layer_type in layer_types:
if isinstance(module, layer_type):
return layer_type
return None
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
"""Get the signature of a module."""
layer_type = self._infer_basic_layer_type(module=module)
assert layer_type is not None, f"Module {module} is not a basic layer"
param_shapes = [p.shape for p in module.parameters()]
return (layer_type, tuple(param_shapes))
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
"""Count the number of basic layers in a module."""
count: DefaultDict[type[nn.Module], int] = defaultdict(int)
for submodule in module.modules():
layer_type = self._infer_basic_layer_type(module=submodule)
if layer_type is not None:
count[layer_type] += 1
        return count

def _verify_basic_layers_count(self) -> bool:
"""Verify that the source and target models have the same number of basic layers."""
source_layers = self._count_basic_layers(module=self.source_model)
target_layers = self._count_basic_layers(module=self.target_model)
reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}
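        # Counts are checked in both directions (source -> target and target -> source) so that a
        # layer type present in only one of the models is reported as a mismatch too.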
diff: dict[type[nn.Module], tuple[int, int]] = {}
for layer_type, source_count in source_layers.items():
target_type = self.custom_layer_mapping.get(layer_type, layer_type)
target_count = target_layers.get(target_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
for layer_type, target_count in target_layers.items():
source_type = reverse_mapping.get(layer_type, layer_type)
source_count = source_layers.get(source_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
if diff:
message = "Models do not have the same number of basic layers:\n"
for layer_type, counts in diff.items():
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
self._log(message=message.strip())
return False
        return True

def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
"""Check if a module is a leaf module with weights."""
        return next(module.parameters(), None) is not None and next(module.children(), None) is None

def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
"""Check if a module has weighted leaf modules that are not basic layers."""
return [
type(submodule)
for submodule in module.modules()
if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
        ]

def _verify_missing_basic_layers(self) -> bool:
"""Verify that the source and target models do not have missing basic layers."""
missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)
if missing_source_layers or missing_target_layers:
self._log(
message=(
"Models might have missing basic layers. If you want to skip this check, set"
f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
)
)
return False
        return True

@no_grad()
def _trace_module_execution_order(
self,
module: nn.Module,
args: ModuleArgs,
keys_to_skip: list[str],
) -> dict[ModelTypeShape, list[str]]:
"""Execute a forward pass and store the order of execution of specific sub-modules.
Args:
module: The module to trace.
args: The arguments to pass to the module it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
keys_to_skip: A list of keys to skip when tracing the module.
Returns:
A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
"""
submodule_to_key: dict[nn.Module, str] = {}
        execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)

def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
layer_signature = self.get_module_signature(module=layer)
            execution_order[layer_signature].append(submodule_to_key[layer])

hooks: list[RemovableHandle] = []
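        # Register a forward hook on every recognized basic layer that is not skipped, so the
        # order in which these layers actually execute during the forward pass gets recorded.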
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
for name, submodule in named_modules:
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
submodule_to_key[submodule] = name # type: ignore
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
positional_args, keyword_args = self._unpack_module_args(module_args=args)
module(*positional_args, **keyword_args)
for hook in hooks:
hook.remove()
        return dict(execution_order)

def _assert_shapes_aligned(
self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
) -> bool:
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
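        # Plain torch layers must line up purely by their (type, parameter shapes) signature;
        # custom layers are handled afterwards by translating types through `custom_layer_mapping`.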
default_type_shapes = [
type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
]
shape_mismatched = False
for model_type_shape in default_type_shapes:
source_keys = source_order.get(model_type_shape, [])
target_keys = target_order.get(model_type_shape, [])
if len(source_keys) != len(target_keys):
shape_mismatched = True
self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
for source_custom_type in self.custom_layer_mapping.keys():
# iterate over all type_shapes that have the same type as source_custom_type
for source_type_shape in [
type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
]:
source_keys = source_order.get(source_type_shape, [])
target_custom_type = self.custom_layer_mapping[source_custom_type]
target_type_shape = (target_custom_type, source_type_shape[1])
target_keys = target_order.get(target_type_shape, [])
if len(source_keys) != len(target_keys):
shape_mismatched = True
self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)
        return not shape_mismatched

@staticmethod
def _convert_state_dict(
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
) -> dict[str, Tensor]:
"""Convert the source model's state_dict to match the target model's state_dict."""
converted_state_dict: dict[str, Tensor] = {}
for target_key in target_state_dict:
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
source_prefix = state_dict_mapping[target_prefix]
source_key = ".".join([source_prefix, suffix])
converted_state_dict[target_key] = source_state_dict[source_key]
        return converted_state_dict

@no_grad()
def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]:
"""Execute a forward pass and store the output of specific sub-modules.
Args:
module: The module to trace.
args: The arguments to pass to the module it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
keys_to_skip: A list of keys to skip when tracing the module.
Returns:
A list of tuples containing the key of each sub-module and its output.
Note:
The output of each sub-module is cloned to avoid memory leaks.
"""
submodule_to_key: dict[nn.Module, str] = {}
        execution_order: list[tuple[str, Tensor]] = []

def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None:
            execution_order.append((submodule_to_key[layer], output.clone()))

hooks: list[RemovableHandle] = []
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
for name, submodule in named_modules:
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
submodule_to_key[submodule] = name # type: ignore
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
positional_args, keyword_args = self._unpack_module_args(module_args=args)
module(*positional_args, **keyword_args)
for hook in hooks:
hook.remove()
return execution_order