mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 13:48:46 +00:00)

delete old conversion scripts

This commit is contained in:
parent 2796117d2d
commit 189cfa1a69

@@ -1 +0,0 @@
::: refiners.fluxion.model_converter

@@ -1,81 +0,0 @@
import argparse
from pathlib import Path

import torch
from diffusers import AutoencoderKL  # type: ignore
from torch import nn

from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder


class Args(argparse.Namespace):
    source_path: str
    output_path: str | None
    use_half: bool
    verbose: bool


def setup_converter(args: Args) -> ModelConverter:
    target = LatentDiffusionAutoencoder()
    # low_cpu_mem_usage=False stops some annoying console messages telling us to `pip install accelerate`
    source: nn.Module = AutoencoderKL.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=args.source_path,
        subfolder=args.subfolder,
        low_cpu_mem_usage=False,
    )  # type: ignore
    x = torch.randn(1, 3, 512, 512)
    converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
    if not converter.run(source_args=(x,)):
        raise RuntimeError("Model conversion failed")
    return converter


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a pretrained diffusers AutoencoderKL model to a refiners Latent Diffusion Autoencoder"
    )
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="runwayml/stable-diffusion-v1-5",
        help="Path to the source pretrained model (default: 'runwayml/stable-diffusion-v1-5').",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        dest="subfolder",
        default="vae",
        help="Subfolder in the source path where the model is located inside the Hub (default: 'vae')",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
            " be the source path with the extension changed to .safetensors."
        ),
    )
    parser.add_argument(
        "--half",
        action="store_true",
        dest="use_half",
        default=False,
        help="Use this flag to save the output file as half precision (default: full precision).",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        dest="verbose",
        default=False,
        help="Use this flag to print verbose output during conversion.",
    )
    args = parser.parse_args(namespace=Args())
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}-autoencoder.safetensors"
    assert args.output_path is not None
    converter = setup_converter(args=args)
    converter.save_to_safetensors(path=args.output_path, half=args.use_half)

@@ -1,233 +0,0 @@
# pyright: reportPrivateUsage=false
import argparse
from pathlib import Path

import torch
from diffusers import ControlNetModel  # type: ignore
from torch import nn

from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import (
    DPMSolver,
    SD1ControlnetAdapter,
    SD1UNet,
)


class Args(argparse.Namespace):
    source_path: str
    output_path: str | None


@no_grad()
def convert(args: Args) -> dict[str, torch.Tensor]:
    # low_cpu_mem_usage=False stops some annoying console messages telling us to `pip install accelerate`
    controlnet_src: nn.Module = ControlNetModel.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=args.source_path,
        low_cpu_mem_usage=False,
    )
    unet = SD1UNet(in_channels=4)
    adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
    controlnet = adapter.controlnet

    condition = torch.randn(1, 3, 512, 512)
    adapter.set_controlnet_condition(condition=condition)

    clip_text_embedding = torch.rand(1, 77, 768)
    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

    solver = DPMSolver(num_inference_steps=10)
    timestep = solver.timesteps[0].unsqueeze(dim=0)
    unet.set_timestep(timestep=timestep.unsqueeze(dim=0))

    x = torch.randn(1, 4, 64, 64)

    # We need the hack below because our implementation is not strictly equivalent
    # to diffusers in terms of execution order, since we compute the residuals inline
    # instead of in a separate step.

    converter = ModelConverter(
        source_model=controlnet_src, target_model=controlnet, skip_output_check=True, verbose=False
    )

    source_order = converter._trace_module_execution_order(
        module=controlnet_src, args=(x, timestep, clip_text_embedding, condition), keys_to_skip=[]
    )
    target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[])

    broken_k = (nn.Conv2d, (torch.Size([320, 320, 1, 1]), torch.Size([320])))

    expected_source_order = [
        "down_blocks.0.attentions.0.proj_in",
        "down_blocks.0.attentions.0.proj_out",
        "down_blocks.0.attentions.1.proj_in",
        "down_blocks.0.attentions.1.proj_out",
        "controlnet_down_blocks.0",
        "controlnet_down_blocks.1",
        "controlnet_down_blocks.2",
        "controlnet_down_blocks.3",
    ]

    expected_target_order = [
        "DownBlocks.Chain_1.Passthrough.Conv2d",
        "DownBlocks.Chain_2.CLIPLCrossAttention.Chain_1.Conv2d",
        "DownBlocks.Chain_2.CLIPLCrossAttention.Chain_3.Conv2d",
        "DownBlocks.Chain_2.Passthrough.Conv2d",
        "DownBlocks.Chain_3.CLIPLCrossAttention.Chain_1.Conv2d",
        "DownBlocks.Chain_3.CLIPLCrossAttention.Chain_3.Conv2d",
        "DownBlocks.Chain_3.Passthrough.Conv2d",
        "DownBlocks.Chain_4.Passthrough.Conv2d",
    ]

    fixed_source_order = [
        "controlnet_down_blocks.0",
        "down_blocks.0.attentions.0.proj_in",
        "down_blocks.0.attentions.0.proj_out",
        "controlnet_down_blocks.1",
        "down_blocks.0.attentions.1.proj_in",
        "down_blocks.0.attentions.1.proj_out",
        "controlnet_down_blocks.2",
        "controlnet_down_blocks.3",
    ]

    assert source_order[broken_k] == expected_source_order
    assert target_order[broken_k] == expected_target_order
    source_order[broken_k] = fixed_source_order

    broken_k = (nn.Conv2d, (torch.Size([640, 640, 1, 1]), torch.Size([640])))

    expected_source_order = [
        "down_blocks.1.attentions.0.proj_in",
        "down_blocks.1.attentions.0.proj_out",
        "down_blocks.1.attentions.1.proj_in",
        "down_blocks.1.attentions.1.proj_out",
        "controlnet_down_blocks.4",
        "controlnet_down_blocks.5",
        "controlnet_down_blocks.6",
    ]

    expected_target_order = [
        "DownBlocks.Chain_5.CLIPLCrossAttention.Chain_1.Conv2d",
        "DownBlocks.Chain_5.CLIPLCrossAttention.Chain_3.Conv2d",
        "DownBlocks.Chain_5.Passthrough.Conv2d",
        "DownBlocks.Chain_6.CLIPLCrossAttention.Chain_1.Conv2d",
        "DownBlocks.Chain_6.CLIPLCrossAttention.Chain_3.Conv2d",
        "DownBlocks.Chain_6.Passthrough.Conv2d",
        "DownBlocks.Chain_7.Passthrough.Conv2d",
    ]

    fixed_source_order = [
        "down_blocks.1.attentions.0.proj_in",
        "down_blocks.1.attentions.0.proj_out",
        "controlnet_down_blocks.4",
        "down_blocks.1.attentions.1.proj_in",
        "down_blocks.1.attentions.1.proj_out",
        "controlnet_down_blocks.5",
        "controlnet_down_blocks.6",
    ]

    assert source_order[broken_k] == expected_source_order
    assert target_order[broken_k] == expected_target_order
    source_order[broken_k] = fixed_source_order

    broken_k = (nn.Conv2d, (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))

    expected_source_order = [
        "down_blocks.2.attentions.0.proj_in",
        "down_blocks.2.attentions.0.proj_out",
        "down_blocks.2.attentions.1.proj_in",
        "down_blocks.2.attentions.1.proj_out",
        "mid_block.attentions.0.proj_in",
        "mid_block.attentions.0.proj_out",
        "controlnet_down_blocks.7",
        "controlnet_down_blocks.8",
        "controlnet_down_blocks.9",
        "controlnet_down_blocks.10",
        "controlnet_down_blocks.11",
        "controlnet_mid_block",
    ]

    expected_target_order = [
        "DownBlocks.Chain_8.CLIPLCrossAttention.Chain_1.Conv2d",
        "DownBlocks.Chain_8.CLIPLCrossAttention.Chain_3.Conv2d",
        "DownBlocks.Chain_8.Passthrough.Conv2d",
        "DownBlocks.Chain_9.CLIPLCrossAttention.Chain_1.Conv2d",
        "DownBlocks.Chain_9.CLIPLCrossAttention.Chain_3.Conv2d",
        "DownBlocks.Chain_9.Passthrough.Conv2d",
        "DownBlocks.Chain_10.Passthrough.Conv2d",
        "DownBlocks.Chain_11.Passthrough.Conv2d",
        "DownBlocks.Chain_12.Passthrough.Conv2d",
        "MiddleBlock.CLIPLCrossAttention.Chain_1.Conv2d",
        "MiddleBlock.CLIPLCrossAttention.Chain_3.Conv2d",
        "MiddleBlock.Passthrough.Conv2d",
    ]

    fixed_source_order = [
        "down_blocks.2.attentions.0.proj_in",
        "down_blocks.2.attentions.0.proj_out",
        "controlnet_down_blocks.7",
        "down_blocks.2.attentions.1.proj_in",
        "down_blocks.2.attentions.1.proj_out",
        "controlnet_down_blocks.8",
        "controlnet_down_blocks.9",
        "controlnet_down_blocks.10",
        "controlnet_down_blocks.11",
        "mid_block.attentions.0.proj_in",
        "mid_block.attentions.0.proj_out",
        "controlnet_mid_block",
    ]

    assert source_order[broken_k] == expected_source_order
    assert target_order[broken_k] == expected_target_order
    source_order[broken_k] = fixed_source_order

    assert converter._assert_shapes_aligned(source_order=source_order, target_order=target_order), "Shapes not aligned"

    mapping: dict[str, str] = {}
    for model_type_shape in source_order:
        source_keys = source_order[model_type_shape]
        target_keys = target_order[model_type_shape]
        mapping.update(zip(target_keys, source_keys))

    state_dict = converter._convert_state_dict(
        source_state_dict=controlnet_src.state_dict(),
        target_state_dict=controlnet.state_dict(),
        state_dict_mapping=mapping,
    )

    return {k: v.half() for k, v in state_dict.items()}


def main() -> None:
    parser = argparse.ArgumentParser(description="Convert a diffusers ControlNet model to a Refiners ControlNet model")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="lllyasviel/sd-controlnet-depth",
        help=(
            "Can be a path to a .bin, a .safetensors file, or a model identifier from Hugging Face Hub. Defaults to"
            " lllyasviel/sd-controlnet-depth"
        ),
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        required=False,
        default=None,
        help=(
            "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
            " source path."
        ),
    )
    args = parser.parse_args(namespace=Args())
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}-controlnet.safetensors"
    state_dict = convert(args=args)
    save_to_safetensors(path=args.output_path, tensors=state_dict)


if __name__ == "__main__":
    main()

@@ -1,154 +0,0 @@
import argparse
from pathlib import Path

import torch

from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet

# Running:
#
# from diffusers import UNet2DConditionModel
# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# for k in unet.attn_processors.keys():
#     print(k)
#
# Gives:
#
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
# ...
# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
# ...
# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
# mid_block.attentions.0.transformer_blocks.0.attn2.processor
#
# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
# indices:
#
# DownBlocks -> [1, 3, 5, 7, 9, 11]
# MiddleBlock -> [31]
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
#
# Same for SDXL with more layers (70 cross-attentions vs. 16)
CROSS_ATTN_MAPPING: dict[str, list[int]] = {
    "sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
    "sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
}
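
# Illustrative sanity check, added for clarity and not part of the original script:
# the index lists above should contain exactly 16 cross-attention indices for SD 1.5
# (6 down + 1 middle + 9 up) and 70 for SDXL, matching the counts in the comment.
assert len(CROSS_ATTN_MAPPING["sd15"]) == 16
assert len(CROSS_ATTN_MAPPING["sdxl"]) == 70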


def main() -> None:
    parser = argparse.ArgumentParser(description="Converts an IP-Adapter diffusers model to refiners.")
    parser.add_argument(
        "--from",
        type=str,
        required=True,
        dest="source_path",
        help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Path to save the converted model. If not specified, the output path will be the source path with the"
            " extension changed to .safetensors."
        ),
    )
    parser.add_argument("--verbose", action="store_true", dest="verbose")
    parser.add_argument("--half", action="store_true", dest="half")
    args = parser.parse_args()
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}.safetensors"

    # Do not use `load_tensors`: first-level values are not tensors.
    weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu")  # type: ignore
    assert isinstance(weights, dict)
    assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]

    image_proj_weights = weights["image_proj"]
    ip_adapter_weights = weights["ip_adapter"]

    fine_grained = "latents" in image_proj_weights  # aka IP-Adapter plus

    match len(ip_adapter_weights):
        case 32:
            ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
            cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
        case 140:
            ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4), fine_grained=fine_grained)
            cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
        case _:
            raise ValueError("Unexpected number of keys in input checkpoint")

    # Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
    # [1]: https://github.com/tencent-ailab/IP-Adapter

    state_dict: dict[str, torch.Tensor] = {}

    image_proj_state_dict: dict[str, torch.Tensor]

    if fine_grained:
        w = image_proj_weights
        image_proj_state_dict = {
            "LatentsToken.Parameter.weight": w["latents"].squeeze(0),  # drop batch dim = 1
            "Linear_1.weight": w["proj_in.weight"],
            "Linear_1.bias": w["proj_in.bias"],
            "Linear_2.weight": w["proj_out.weight"],
            "Linear_2.bias": w["proj_out.bias"],
            "LayerNorm.weight": w["norm_out.weight"],
            "LayerNorm.bias": w["norm_out.bias"],
        }
        for i in range(4):
            t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}."
            image_proj_state_dict.update(
                {
                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"],
                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"],
                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"],
                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"],
                    f"{t_pfx}1.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"],
                    f"{t_pfx}1.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"],
                    f"{t_pfx}1.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"],
                    f"{t_pfx}2.LayerNorm.weight": w[f"{s_pfx}1.0.weight"],
                    f"{t_pfx}2.LayerNorm.bias": w[f"{s_pfx}1.0.bias"],
                    f"{t_pfx}2.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"],
                    f"{t_pfx}2.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"],
                }
            )
    else:
        image_proj_state_dict = {
            "Linear.weight": image_proj_weights["proj.weight"],
            "Linear.bias": image_proj_weights["proj.bias"],
            "LayerNorm.weight": image_proj_weights["norm.weight"],
            "LayerNorm.bias": image_proj_weights["norm.bias"],
        }
    ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)

    for k, v in image_proj_state_dict.items():
        state_dict[f"image_proj.{k}"] = v

    assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2

    for i, _ in enumerate(ip_adapter.sub_adapters):
        cross_attn_index = cross_attn_mapping[i]
        k_ip = f"{cross_attn_index}.to_k_ip.weight"
        v_ip = f"{cross_attn_index}.to_v_ip.weight"

        # the name of the key is not checked at runtime, so we keep the original name
        state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
        state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]

    if args.half:
        state_dict = {key: value.half() for key, value in state_dict.items()}
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}.safetensors"
    save_to_safetensors(path=args.output_path, tensors=state_dict)


if __name__ == "__main__":
    main()

@@ -1,63 +0,0 @@
import argparse
from pathlib import Path

import torch
from diffusers import T2IAdapter  # type: ignore
from torch import nn

from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, ConditionEncoderXL

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a pretrained diffusers T2I-Adapter model to refiners")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        required=True,
        help="Path or repository name of the source model. (e.g.: 'ip-adapter_sd15.bin').",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
            " be the source path with the extension changed to .safetensors."
        ),
    )
    parser.add_argument(
        "--half",
        action="store_true",
        dest="use_half",
        default=False,
        help="Use this flag to save the output file as half precision (default: full precision).",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        dest="verbose",
        default=False,
        help="Use this flag to print verbose output during conversion.",
    )
    args = parser.parse_args()
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).name}.safetensors"
    assert args.output_path is not None

    sdxl = "xl" in args.source_path
    target = ConditionEncoderXL() if sdxl else ConditionEncoder()
    # low_cpu_mem_usage=False stops some annoying console messages telling us to `pip install accelerate`
    source: nn.Module = T2IAdapter.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=args.source_path,
        low_cpu_mem_usage=False,
    )
    assert isinstance(source, nn.Module), "Source model is not a nn.Module"

    x = torch.randn(1, 3, 1024, 1024) if sdxl else torch.randn(1, 3, 512, 512)
    converter = ModelConverter(source_model=source, target_model=target, verbose=args.verbose)
    if not converter.run(source_args=(x,)):
        raise RuntimeError("Model conversion failed")

    converter.save_to_safetensors(path=args.output_path, half=args.use_half)

@@ -1,148 +0,0 @@
import argparse
from pathlib import Path
from typing import Any

import torch
from diffusers import UNet2DConditionModel  # type: ignore
from torch import nn

from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_from_safetensors, load_tensors
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter


class Args(argparse.Namespace):
    source_path: str
    output_path: str | None
    subfolder: str
    half: bool
    verbose: bool
    skip_init_check: bool
    override_weights: str | None


def setup_converter(args: Args) -> ModelConverter:
    # low_cpu_mem_usage=False stops some annoying console messages telling us to `pip install accelerate`
    source: nn.Module = UNet2DConditionModel.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=args.source_path,
        subfolder=args.subfolder,
        low_cpu_mem_usage=False,
    )
    if args.override_weights is not None:
        if args.override_weights.endswith(".pth"):
            sd = load_tensors(args.override_weights)
        elif args.override_weights.endswith(".safetensors"):
            sd = load_from_safetensors(args.override_weights)
        else:
            raise ValueError(f"Unsupported file format: {args.override_weights}")
        source.load_state_dict(sd)
    source_in_channels: int = source.config.in_channels  # type: ignore
    source_clip_embedding_dim: int = source.config.cross_attention_dim  # type: ignore
    source_has_time_ids: bool = source.config.addition_embed_type == "text_time"  # type: ignore
    source_is_lcm: bool = source.config.time_cond_proj_dim is not None

    if source_has_time_ids:
        target = SDXLUNet(in_channels=source_in_channels)
    else:
        target = SD1UNet(in_channels=source_in_channels)

    if source_is_lcm:
        assert isinstance(target, SDXLUNet)
        SDXLLcmAdapter(target=target).inject()

    x = torch.randn(1, source_in_channels, 32, 32)
    timestep = torch.tensor(data=[0])
    clip_text_embeddings = torch.randn(1, 77, source_clip_embedding_dim)

    target.set_timestep(timestep=timestep)
    target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
    added_cond_kwargs = {}
    if isinstance(target, SDXLUNet):
        added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
        target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
        target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])

    target_args = (x,)

    source_kwargs: dict[str, Any] = {}
    if source_has_time_ids:
        source_kwargs["added_cond_kwargs"] = added_cond_kwargs
    if source_is_lcm:
        source_kwargs["timestep_cond"] = torch.randn(1, source.config.time_cond_proj_dim)

    source_args = {
        "positional": (x, timestep, clip_text_embeddings),
        "keyword": source_kwargs,
    }

    converter = ModelConverter(
        source_model=source,
        target_model=target,
        skip_init_check=args.skip_init_check,
        skip_output_check=True,
        verbose=args.verbose,
    )
    if not converter.run(
        source_args=source_args,
        target_args=target_args,
    ):
        raise RuntimeError("Model conversion failed")
    return converter


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Converts a Diffusion UNet model to a Refiners SD1UNet or SDXLUNet model"
    )
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="runwayml/stable-diffusion-v1-5",
        help=(
            "Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
            " runwayml/stable-diffusion-v1-5"
        ),
    )
    parser.add_argument(
        "--override-weights",
        type=str,
        default=None,
        help=(
            "Path to a weights file to override the source model (keeping its config). "
            "This is useful for models distributed as .pth files."
        ),
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
            " source path."
        ),
    )
    parser.add_argument("--subfolder", type=str, default="unet", help="Subfolder. Default: unet.")
    parser.add_argument(
        "--skip-init-check",
        action="store_true",
        help="Skip check that source and target have the same layers count.",
    )
    parser.add_argument("--half", action="store_true", help="Convert to half precision.")
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    args = parser.parse_args(namespace=Args())
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}-unet.safetensors"
    converter = setup_converter(args=args)
    converter.save_to_safetensors(path=args.output_path, half=args.half)


if __name__ == "__main__":
    main()

@@ -1,176 +0,0 @@
import argparse
from pathlib import Path

import torch

from refiners.fluxion.utils import load_tensors, save_to_safetensors


def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
    """Convert DINOv2 weights from facebook to refiners."""
    # get depth from "blocks" keys
    depth = max([int(k.split(".")[1]) for k in weights.keys() if k.startswith("blocks.")]) + 1

    # only needed when pre-training
    del weights["mask_token"]

    # squeeze cls_token and position_embeddings
    weights["cls_token"] = weights["cls_token"].squeeze(0)
    weights["pos_embed"] = weights["pos_embed"].squeeze(0)

    # rename "w12" to "fc1" and "w3" to "fc2", only for giant model
    for key in list(weights.keys()):
        if "w3" in key:
            new_key = key.replace("w3", "fc2")
            weights[new_key] = weights.pop(key)
        elif "w12" in key:
            # we swap w1 and w2 because of the difference between our GLU implementation and theirs
            # see https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L31-L34
            # and https://github.com/finegrain-ai/refiners/blob/a2ee70578361e4d84a65a8708564480a9b0ec67e/src/refiners/fluxion/layers/activations.py#L158-L160
            weight = weights.pop(key)
            w1, w2 = weight.chunk(2, dim=0)
            w21 = torch.cat([w2, w1], dim=0)
            new_key = key.replace("w12", "fc1")
            weights[new_key] = w21

    rename_keys: list[tuple[str, str]] = [
        ("cls_token", "Concatenate.ClassToken.Parameter.weight"),
        ("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
        ("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
        ("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
        ("norm.weight", "LayerNorm.weight"),
        ("norm.bias", "LayerNorm.bias"),
    ]
    for i in range(depth):
        rename_keys.append(
            (
                f"blocks.{i}.norm1.weight",
                f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.weight",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.norm1.bias",
                f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.bias",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.attn.proj.weight",
                f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.weight",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.attn.proj.bias",
                f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.bias",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.ls1.gamma",
                f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerScale.weight",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.norm2.weight",
                f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.weight",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.norm2.bias",
                f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.bias",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.mlp.fc1.weight",
                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.weight",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.mlp.fc1.bias",
                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.bias",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.mlp.fc2.weight",
                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.weight",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.mlp.fc2.bias",
                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.bias",
            ),
        )
        rename_keys.append(
            (
                f"blocks.{i}.ls2.gamma",
                f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerScale.weight",
            ),
        )

    if "register_tokens" in weights:
        weights["register_tokens"] = weights["register_tokens"].squeeze(0)
        rename_keys.append(("register_tokens", "Registers.Parameter.weight"))

    # rename keys
    for old_key, new_key in rename_keys:
        weights[new_key] = weights.pop(old_key)

    # split the qkv weights and biases
    for i in range(depth):
        qkv_weight = weights.pop(f"blocks.{i}.attn.qkv.weight")
        q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.weight"] = q_weight
        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.weight"] = k_weight
        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.weight"] = v_weight

        qkv_bias = weights.pop(f"blocks.{i}.attn.qkv.bias")
        q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.bias"] = q_bias
        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.bias"] = k_bias
        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.bias"] = v_bias


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--from",
        type=str,
        required=True,
        dest="source_path",
        help=(
            "Official checkpoint from https://github.com/facebookresearch/dinov2"
            " e.g. /path/to/dinov2_vits14_pretrain.pth"
        ),
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Path to save the converted model. If not specified, the output path will be the source path with the"
            " extension changed to .safetensors."
        ),
    )
    parser.add_argument("--half", action="store_true", dest="half")
    args = parser.parse_args()

    weights = load_tensors(args.source_path)
    convert_dinov2_facebook(weights)
    if args.half:
        weights = {key: value.half() for key, value in weights.items()}
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}.safetensors"
    save_to_safetensors(path=args.output_path, tensors=weights)


if __name__ == "__main__":
    main()

@@ -1,102 +0,0 @@
import argparse
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download  # type: ignore

from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors


class Args(argparse.Namespace):
    source_path: str
    output_path: str | None
    use_half: bool


def convert(args: Args) -> dict[str, torch.Tensor]:
    if Path(args.source_path).suffix != ".safetensors":
        args.source_path = hf_hub_download(
            repo_id=args.source_path, filename="ella-sd1.5-tsc-t5xl.safetensors", local_dir="tests/weights/ELLA-Adapter"
        )
    weights = load_from_safetensors(args.source_path)

    for key in list(weights.keys()):
        if "latents" in key:
            new_key = "PerceiverResampler.Latents.ParameterInitialized.weight"
            weights[new_key] = weights.pop(key)
        elif "time_embedding" in key:
            new_key = key.replace("time_embedding", "TimestepEncoder.RangeEncoder").replace("linear", "Linear")
            weights[new_key] = weights.pop(key)
        elif "proj_in" in key:
            new_key = f"PerceiverResampler.Linear.{key.split('.')[-1]}"
            weights[new_key] = weights.pop(key)
        elif "time_aware" in key:
            new_key = f"PerceiverResampler.Residual.Linear.{key.split('.')[-1]}"
            weights[new_key] = weights.pop(key)
        elif "attn.in_proj" in key:
            layer_num = int(key.split(".")[2])
            query_param, key_param, value_param = weights.pop(key).chunk(3, dim=0)
            param_type = "weight" if "weight" in key else "bias"
            for i, param in enumerate([query_param, key_param, value_param]):
                new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Distribute.Linear_{i+1}.{param_type}"
                weights[new_key] = param
        elif "attn.out_proj" in key:
            layer_num = int(key.split(".")[2])
            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Linear.{key.split('.')[-1]}"
            weights[new_key] = weights.pop(key)
        elif "ln_ff" in key:
            layer_num = int(key.split(".")[2])
            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.AdaLayerNorm.Parallel.Chain.Linear.{key.split('.')[-1]}"
            weights[new_key] = weights.pop(key)
        elif "ln_1" in key or "ln_2" in key:
            layer_num = int(key.split(".")[2])
            n = 1 if int(key.split(".")[3].split("_")[-1]) == 2 else 2
            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Distribute.AdaLayerNorm_{n}.Parallel.Chain.Linear.{key.split('.')[-1]}"
            weights[new_key] = weights.pop(key)
        elif "mlp" in key:
            layer_num = int(key.split(".")[2])
            n = 1 if "c_fc" in key else 2
            new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.FeedForward.Linear_{n}.{key.split('.')[-1]}"
            weights[new_key] = weights.pop(key)

    if args.use_half:
        weights = {key: value.half() for key, value in weights.items()}

    return weights


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a pretrained ELLA Adapter to the refiners implementation")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="QQGYLab/ELLA",
        help=(
            "A path to a local .safetensors weights file. If not provided, a repo from the Hugging Face Hub will be"
            " used. Defaults to QQGYLab/ELLA."
        ),
    )

    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
            " be the source path with the prefix set to refiners."
        ),
    )
    parser.add_argument(
        "--half",
        action="store_true",
        dest="use_half",
        default=True,
        help="Use this flag to save the output file as half precision (default: full precision).",
    )
    args = parser.parse_args(namespace=Args())
    weights = convert(args)
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
    save_to_safetensors(path=args.output_path, tensors=weights)

@@ -1,348 +0,0 @@
import argparse
import logging
from logging import info
from pathlib import Path

from huggingface_hub import hf_hub_download  # type: ignore
from torch import Tensor
from torch.nn import Parameter as TorchParameter

from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras
from refiners.fluxion.layers import Conv2d
from refiners.fluxion.layers.linear import Linear
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import (
    ConditionEncoder,
    ControlLora,
    ControlLoraAdapter,
    ZeroConvolution,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL


def sort_keys(key: str, /) -> tuple[str, int]:
    """Compute the score of a key, relative to its suffix.

    When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level".

    Args:
        key: The key to sort.

    Returns:
        The padded suffix of the key.
        The score of the key's suffix.
    """
    if "time_embed" in key:  # HACK: will place the "time_embed" layers at the very start of the list
        return ("", -2)

    if "label_emb" in key:  # HACK: will place the "label_emb" layers right after "time_embed"
        return ("", -1)

    if "proj_out" in key:  # HACK: will place the "proj_out" layers at the end of each "transformer_blocks"
        return (key.removesuffix("proj_out") + "transformer_blocks.99.ff.net.2", 10)

    return SDLoraManager.sort_keys(key)


def load_lora_layers(
    name: str,
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> dict[str, Lora[Linear | Conv2d]]:
    """Load the LoRA layers from the state_dict into the ControlLora.

    Args:
        name: The name of the LoRA.
        state_dict: The state_dict of the LoRA.
        control_lora: The ControlLora to load the LoRA layers into.
    """
    # filter from the state_dict the layers that will be used for the LoRA layers
    lora_weights = {f"{key}.weight": value for key, value in state_dict.items() if ".up" in key or ".down" in key}

    # move the tensors to the device and dtype of the ControlLora
    lora_weights = {
        key: value.to(
            dtype=control_lora.dtype,
            device=control_lora.device,
        )
        for key, value in lora_weights.items()
    }

    # load every LoRA layer from the filtered state_dict
    lora_layers = Lora.from_dict(name, state_dict=lora_weights)

    # sort all the LoRA's keys using the `sort_keys` method
    lora_layers = {
        key: lora_layers[key]
        for key in sorted(
            lora_layers.keys(),
            key=sort_keys,
        )
    }

    # auto-attach the LoRA layers to the U-Net
    auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"])

    # eject all the LoRA adapters from the U-Net,
    # because we need each target path as if the adapter wasn't injected
    for lora_layer in lora_layers.values():
        lora_adapter = lora_layer.parent
        assert isinstance(lora_adapter, LoraAdapter)
        lora_adapter.eject()

    return lora_layers


def load_condition_encoder(
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None:
    """Load the ConditionEncoder's Conv2d layers from the state_dict into the ControlLora.

    Args:
        state_dict: The state_dict of the ConditionEncoder.
        control_lora: The control_lora to load the ConditionEncoder's Conv2d layers into.
    """
    # filter from the state_dict the layers that will be used for the ConditionEncoder
    condition_encoder_tensors = {key: value for key, value in state_dict.items() if "input_hint_block" in key}

    # move the tensors to the device and dtype of the ControlLora
    condition_encoder_tensors = {
        key: value.to(
            dtype=control_lora.dtype,
            device=control_lora.device,
        )
        for key, value in condition_encoder_tensors.items()
    }

    # find the ConditionEncoder's Conv2d layers
    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
    condition_encoder_conv2ds = list(condition_encoder_layer.layers(Conv2d))

    # replace the Conv2d layers' weights and biases with the ones from the state_dict
    for i, layer in enumerate(condition_encoder_conv2ds):
        layer.weight = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.weight"])
        layer.bias = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.bias"])


def load_zero_convolutions(
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None:
    """Load the ZeroConvolution's Conv2d layers from the state_dict into the ControlLora.

    Args:
        state_dict: The state_dict of the ZeroConvolution.
        control_lora: The ControlLora to load the ZeroConvolution's Conv2d layers into.
    """
    # filter from the state_dict the layers that will be used for the ZeroConvolution layers
    zero_convolution_tensors = {key: value for key, value in state_dict.items() if "zero_convs" in key}
    n = len(zero_convolution_tensors) // 2
    zero_convolution_tensors[f"zero_convs.{n}.0.weight"] = state_dict["middle_block_out.0.weight"]
    zero_convolution_tensors[f"zero_convs.{n}.0.bias"] = state_dict["middle_block_out.0.bias"]

    # move the tensors to the device and dtype of the ControlLora
    zero_convolution_tensors = {
        key: value.to(
            dtype=control_lora.dtype,
            device=control_lora.device,
        )
        for key, value in zero_convolution_tensors.items()
    }

    # find the ZeroConvolution's Conv2d layers
    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
    zero_convolution_conv2ds = [layer.ensure_find(Conv2d) for layer in zero_convolution_layers]

    # replace the Conv2d layers' weights and biases with the ones from the state_dict
    for i, layer in enumerate(zero_convolution_conv2ds):
        layer.weight = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.weight"])
        layer.bias = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.bias"])


def simplify_key(key: str, prefix: str, index: int | None = None) -> str:
    """Simplify a key by stripping everything to the left of the prefix.

    Also optionally add a zero-padded index to the prefix.

    Example:
        >>> simplify_key("foo.bar.ControlLora.something", "ControlLora", 1)
        "ControlLora_01.something"

        >>> simplify_key("foo.bar.ControlLora.DownBlocks.something", "ControlLora")
        "ControlLora.DownBlocks.something"

    Args:
        key: The key to simplify.
        prefix: The prefix to remove.
        index: The index to add.
    """
    _, right = key.split(prefix, maxsplit=1)
    if index:
        return f"{prefix}_{index:02d}{right}"
    else:
        return f"{prefix}{right}"


def convert_lora_layers(
    lora_layers: dict[str, Lora[Linear | Conv2d]],
    control_lora: ControlLora,
    refiners_state_dict: dict[str, Tensor],
) -> None:
    """Convert the LoRA layers to the refiners format.

    Args:
        lora_layers: The LoRA layers to convert.
        control_lora: The ControlLora to convert the LoRA layers from.
        refiners_state_dict: The refiners state dict to update with the converted LoRA layers.
    """
    for lora_layer in lora_layers.values():
        # get the adapter associated with the LoRA layer
        lora_adapter = lora_layer.parent
        assert isinstance(lora_adapter, LoraAdapter)

        # get the path of the adapter's target in the ControlLora
        target = lora_adapter.target
        path = target.get_path(parent=control_lora.ensure_find_parent(target))

        state_dict = {
            f"{path}.down": lora_layer.down.weight,
            f"{path}.up": lora_layer.up.weight,
        }
        state_dict = {simplify_key(key, "ControlLora."): param for key, param in state_dict.items()}
        refiners_state_dict.update(state_dict)


def convert_zero_convolutions(
    control_lora: ControlLora,
    refiners_state_dict: dict[str, Tensor],
) -> None:
    """Convert the ZeroConvolution layers to the refiners format.

    Args:
        control_lora: The ControlLora to convert the ZeroConvolution layers from.
        refiners_state_dict: The refiners state dict to update with the converted ZeroConvolution layers.
    """
    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
    for i, zero_convolution_layer in enumerate(zero_convolution_layers):
        state_dict = zero_convolution_layer.state_dict()
        path = zero_convolution_layer.get_path()
        state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
        state_dict = {simplify_key(key, "ZeroConvolution", i + 1): param for key, param in state_dict.items()}
        refiners_state_dict.update(state_dict)


def convert_condition_encoder(
    control_lora: ControlLora,
    refiners_state_dict: dict[str, Tensor],
) -> None:
    """Convert the ConditionEncoder to the refiners format.

    Args:
        control_lora: The ControlLora to convert the ConditionEncoder from.
        refiners_state_dict: The refiners state dict to update with the converted ConditionEncoder.
    """
    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
    path = condition_encoder_layer.get_path()
    state_dict = condition_encoder_layer.state_dict()
    state_dict = {f"{path}.{key}": param for key, param in state_dict.items()}
    state_dict = {simplify_key(key, "ConditionEncoder"): param for key, param in state_dict.items()}
    refiners_state_dict.update(state_dict)


def convert(
    name: str,
    state_dict_path: Path,
    output_path: Path,
) -> None:
    sdxl = StableDiffusion_XL()
    info("Stable Diffusion XL model initialized")

    fooocus_state_dict = load_from_safetensors(state_dict_path)
    info(f"Fooocus weights loaded from: {state_dict_path}")

    control_lora_adapter = ControlLoraAdapter(target=sdxl.unet, name=name).inject()
    control_lora = control_lora_adapter.control_lora
    info("ControlLoraAdapter initialized")

    lora_layers = load_lora_layers(name, fooocus_state_dict, control_lora)
    info("LoRA layers loaded")

    load_zero_convolutions(fooocus_state_dict, control_lora)
    info("ZeroConvolution layers loaded")

    load_condition_encoder(fooocus_state_dict, control_lora)
    info("ConditionEncoder loaded")

    refiners_state_dict: dict[str, Tensor] = {}
    convert_lora_layers(lora_layers, control_lora, refiners_state_dict)
    info("LoRA layers converted to refiners format")

    convert_zero_convolutions(control_lora, refiners_state_dict)
    info("ZeroConvolution layers converted to refiners format")

    convert_condition_encoder(control_lora, refiners_state_dict)
    info("ConditionEncoder converted to refiners format")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_to_safetensors(path=output_path, tensors=refiners_state_dict)
    info(f"Converted ControlLora state dict saved to disk at: {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert ControlLora (from Fooocus) weights to refiners.",
    )

    parser.add_argument(
        "--from",
        type=Path,
        dest="source_path",
        default="lllyasviel/misc:control-lora-canny-rank128.safetensors",
        help="Path to the state_dict of the ControlLora, or a Hugging Face model ID.",
    )

    parser.add_argument(
        "--to",
        type=Path,
        dest="output_path",
        help=(
            "Path to save the converted model (extension will be .safetensors). "
            "If not specified, the output path will be the source path with the extension changed to .safetensors."
        ),
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        dest="verbose",
        default=False,
        help="Use this flag to print verbose output during conversion.",
    )

    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(
            level=logging.INFO,
            format="%(levelname)s: %(message)s",
        )

    if not args.source_path.exists():
        repo_id, filename = str(args.source_path).split(":")
        args.source_path = Path(
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
            )
        )

    if args.output_path is None:
        args.output_path = Path(f"refiners_{args.source_path.stem}.safetensors")

    convert(
        name=args.source_path.stem,
        state_dict_path=args.source_path,
        output_path=args.output_path,
    )

@@ -1,81 +0,0 @@
import argparse

from torch import Tensor

from refiners.fluxion.utils import load_tensors, save_to_safetensors


def main() -> None:
    parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        required=True,
        default="sam_hq_vit_h.pth",
        help="Path to the source model checkpoint.",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        required=True,
        default="refiners_sam_hq_vit_h.safetensors",
        help="Path to save the converted model in Refiners format.",
    )
    args = parser.parse_args()

    source_state_dict = load_tensors(args.source_path)

    state_dict: dict[str, Tensor] = {}

    for suffix in ["weight", "bias"]:
        state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_1.{suffix}"] = source_state_dict[
            f"mask_decoder.compress_vit_feat.0.{suffix}"
        ]
        state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_1.{suffix}"] = source_state_dict[
            f"mask_decoder.embedding_encoder.0.{suffix}"
        ]
        state_dict[f"EmbeddingMaskfeature.Conv2d_1.{suffix}"] = source_state_dict[
            f"mask_decoder.embedding_maskfeature.0.{suffix}"
        ]

        state_dict[f"HQFeatures.CompressViTFeat.LayerNorm2d.{suffix}"] = source_state_dict[
            f"mask_decoder.compress_vit_feat.1.{suffix}"
        ]
        state_dict[f"HQFeatures.EmbeddingEncoder.LayerNorm2d.{suffix}"] = source_state_dict[
            f"mask_decoder.embedding_encoder.1.{suffix}"
        ]
        state_dict[f"EmbeddingMaskfeature.LayerNorm2d.{suffix}"] = source_state_dict[
            f"mask_decoder.embedding_maskfeature.1.{suffix}"
        ]

        state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_2.{suffix}"] = source_state_dict[
            f"mask_decoder.compress_vit_feat.3.{suffix}"
        ]
        state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_2.{suffix}"] = source_state_dict[
            f"mask_decoder.embedding_encoder.3.{suffix}"
        ]
        state_dict[f"EmbeddingMaskfeature.Conv2d_2.{suffix}"] = source_state_dict[
            f"mask_decoder.embedding_maskfeature.3.{suffix}"
        ]

    state_dict = {f"Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ.{k}": v for k, v in state_dict.items()}

    # HQ Token
    state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"]

    # HQ MLP
    for i in range(3):
        state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.weight"] = source_state_dict[
            f"mask_decoder.hf_mlp.layers.{i}.weight"
        ]
        state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.bias"] = source_state_dict[
            f"mask_decoder.hf_mlp.layers.{i}.bias"
        ]

    save_to_safetensors(path=args.output_path, tensors=state_dict)


if __name__ == "__main__":
    main()

@@ -1,89 +0,0 @@
import argparse
from pathlib import Path

from convert_diffusers_unet import Args as UNetArgs, setup_converter as setup_unet_converter
from huggingface_hub import hf_hub_download  # type: ignore

from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors


class Args(argparse.Namespace):
    source_path: str
    output_path: str | None
    subfolder: str
    half: bool
    verbose: bool
    reference_unet_path: str


def main() -> None:
    parser = argparse.ArgumentParser(description="Converts IC-Light patch weights to work with Refiners")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="lllyasviel/ic-light",
        help=(
            "Can be a path to a .bin file, a .safetensors file or a model name from the Hugging Face Hub. Default:"
            " lllyasviel/ic-light"
        ),
    )
    parser.add_argument("--filename", type=str, default="iclight_sd15_fc.safetensors", help="Filename inside the hub.")
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
            " source path."
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    parser.add_argument(
        "--reference-unet-path",
        type=str,
        dest="reference_unet_path",
        default="runwayml/stable-diffusion-v1-5",
        help="Path to the reference UNet weights.",
    )
    args = parser.parse_args(namespace=Args())
    if args.output_path is None:
        args.output_path = f"{Path(args.filename).stem}-refiners.safetensors"

    patch_file = (
        Path(args.source_path)
        if args.source_path.endswith(".safetensors")
        else Path(
            hf_hub_download(
                repo_id=args.source_path,
                filename=args.filename,
            )
        )
    )
    patch_weights = load_from_safetensors(patch_file)

    unet_args = UNetArgs(
        source_path=args.reference_unet_path,
        subfolder="unet",
        half=False,
        verbose=False,
        skip_init_check=True,
        override_weights=None,
    )
    converter = setup_unet_converter(args=unet_args)
    result = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
        source_state_dict=patch_weights,
        target_state_dict=converter.target_model.state_dict(),
        state_dict_mapping=converter.get_mapping(),
    )
    save_to_safetensors(path=args.output_path, tensors=result)


if __name__ == "__main__":
    main()
|
|
@ -1,65 +0,0 @@
|
|||
import argparse
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import load_tensors
|
||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||
|
||||
try:
|
||||
from model import Generator # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
|
||||
" PYTHONPATH"
|
||||
)
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
source_path: str
|
||||
output_path: str
|
||||
verbose: bool
|
||||
half: bool
|
||||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
source = cast(nn.Module, Generator(3, 1, 3))
|
||||
source.load_state_dict(state_dict=load_tensors(args.source_path))
|
||||
source.eval()
|
||||
target = InformativeDrawings()
|
||||
x = torch.randn(1, 3, 512, 512)
|
||||
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||
if not converter.run(source_args=(x,)):
|
||||
raise RuntimeError("Model conversion failed")
|
||||
return converter
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Converts a pretrained Informative Drawings model to a refiners Informative Drawings model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
dest="source_path",
|
||||
default="model2.pth",
|
||||
help="Path to the source model. (default: 'model2.pth').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default="informative-drawings.safetensors",
|
||||
help="Path to save the converted model. (default: 'informative-drawings.safetensors').",
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", dest="verbose")
|
||||
parser.add_argument("--half", action="store_true", dest="half")
|
||||
args = parser.parse_args(namespace=Args())
|
||||
converter = setup_converter(args=args)
|
||||
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,40 +0,0 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||||
from refiners.foundationals.swin.mvanet.converter import convert_weights
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
required=True,
|
||||
dest="source_path",
|
||||
help="A MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default=None,
|
||||
help=(
|
||||
"Path to save the converted model. If not specified, the output path will be the source path with the"
|
||||
" extension changed to .safetensors."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", dest="half")
|
||||
args = parser.parse_args()
|
||||
|
||||
src_weights = load_tensors(args.source_path)
|
||||
weights = convert_weights(src_weights)
|
||||
if args.half:
|
||||
weights = {key: value.half() for key, value in weights.items()}
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||
save_to_safetensors(path=args.output_path, tensors=weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,268 +0,0 @@
|
|||
import argparse
|
||||
import types
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from segment_anything import build_sam_vit_h # type: ignore
|
||||
from segment_anything.modeling.common import LayerNorm2d # type: ignore
|
||||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
|
||||
from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
||||
|
||||
|
||||
class FacebookSAM(nn.Module):
|
||||
image_encoder: nn.Module
|
||||
prompt_encoder: nn.Module
|
||||
mask_decoder: nn.Module
|
||||
|
||||
|
||||
build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h)
|
||||
|
||||
|
||||
assert issubclass(LayerNorm2d, nn.Module)
|
||||
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
source_path: str
|
||||
output_path: str
|
||||
half: bool
|
||||
verbose: bool
|
||||
|
||||
|
||||
def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
refiners_mask_encoder = MaskEncoder()
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=prompt_encoder.mask_downscaling,
|
||||
target_model=refiners_mask_encoder,
|
||||
custom_layer_mapping=custom_layers, # type: ignore
|
||||
)
|
||||
|
||||
x = torch.randn(1, 256, 256)
|
||||
mapping = converter.map_state_dicts(source_args=(x,))
|
||||
assert mapping
|
||||
|
||||
source_state_dict = prompt_encoder.mask_downscaling.state_dict()
|
||||
target_state_dict = refiners_mask_encoder.state_dict()
|
||||
|
||||
# Mapping handled manually (see below) because nn.Parameter is a special case
|
||||
del target_state_dict["no_mask_embedding"]
|
||||
|
||||
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
|
||||
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||
)
|
||||
|
||||
state_dict: dict[str, Tensor] = {
|
||||
"no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore
|
||||
}
|
||||
|
||||
state_dict.update(converted_source)
|
||||
|
||||
refiners_mask_encoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [
|
||||
prompt_encoder.not_a_point_embed.weight
|
||||
] # type: ignore
|
||||
pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore
|
||||
assert isinstance(pe, Tensor)
|
||||
state_dict: dict[str, Tensor] = {
|
||||
"Residual.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)),
|
||||
"CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()),
|
||||
}
|
||||
|
||||
refiners_prompt_encoder = PointEncoder()
|
||||
refiners_prompt_encoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
refiners_sam_vit_h = SAMViTH()
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=vit,
|
||||
target_model=refiners_sam_vit_h,
|
||||
custom_layer_mapping=custom_layers, # type: ignore
|
||||
)
|
||||
converter.skip_init_check = True
|
||||
|
||||
x = torch.randn(1, 3, 1024, 1024)
|
||||
mapping = converter.map_state_dicts(source_args=(x,))
|
||||
assert mapping
|
||||
|
||||
mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"
|
||||
|
||||
target_state_dict = refiners_sam_vit_h.state_dict()
|
||||
del target_state_dict["PositionalEncoder.Parameter.weight"]
|
||||
|
||||
source_state_dict = vit.state_dict()
|
||||
pos_embed = source_state_dict["pos_embed"]
|
||||
del source_state_dict["pos_embed"]
|
||||
|
||||
target_rel_keys = [
|
||||
(
|
||||
f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.horizontal_embedding",
|
||||
f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.vertical_embedding",
|
||||
)
|
||||
for i in range(1, 33)
|
||||
]
|
||||
source_rel_keys = [(f"blocks.{i}.attn.rel_pos_w", f"blocks.{i}.attn.rel_pos_h") for i in range(32)]
|
||||
|
||||
rel_items: dict[str, Tensor] = {}
|
||||
|
||||
for (key_w, key_h), (target_key_w, target_key_h) in zip(source_rel_keys, target_rel_keys):
|
||||
rel_items[target_key_w] = source_state_dict[key_w]
|
||||
rel_items[target_key_h] = source_state_dict[key_h]
|
||||
del source_state_dict[key_w]
|
||||
del source_state_dict[key_h]
|
||||
del target_state_dict[target_key_w]
|
||||
del target_state_dict[target_key_h]
|
||||
|
||||
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
|
||||
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||
)
|
||||
|
||||
positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder)
|
||||
embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight)
|
||||
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
|
||||
converted_source.update(rel_items)
|
||||
|
||||
refiners_sam_vit_h.load_state_dict(state_dict=converted_source)
|
||||
assert converter.compare_models((x,), threshold=1e-2)
|
||||
|
||||
return converted_source
|
||||
|
||||
|
||||
def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
|
||||
refiners_mask_decoder = MaskDecoder()
|
||||
|
||||
image_embedding = torch.randn(1, 256, 64, 64)
|
||||
dense_positional_embedding = torch.randn(1, 256, 64, 64)
|
||||
point_embedding = torch.randn(1, 3, 256)
|
||||
mask_embedding = torch.randn(1, 256, 64, 64)
|
||||
|
||||
from segment_anything.modeling.common import LayerNorm2d # type: ignore
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
assert issubclass(LayerNorm2d, nn.Module)
|
||||
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=mask_decoder,
|
||||
target_model=refiners_mask_decoder,
|
||||
custom_layer_mapping=custom_layers, # type: ignore
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"image_embeddings": image_embedding,
|
||||
"image_pe": dense_positional_embedding,
|
||||
"sparse_prompt_embeddings": point_embedding,
|
||||
"dense_prompt_embeddings": mask_embedding,
|
||||
"multimask_output": True,
|
||||
}
|
||||
|
||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||
refiners_mask_decoder.set_point_embedding(point_embedding)
|
||||
refiners_mask_decoder.set_mask_embedding(mask_embedding)
|
||||
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
|
||||
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
||||
assert mapping is not None
|
||||
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
|
||||
|
||||
state_dict = converter._convert_state_dict( # type: ignore
|
||||
source_state_dict=mask_decoder.state_dict(),
|
||||
target_state_dict=refiners_mask_decoder.state_dict(),
|
||||
state_dict_mapping=mapping,
|
||||
)
|
||||
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
|
||||
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
|
||||
) # type: ignore
|
||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||
refiners_mask_decoder.set_point_embedding(point_embedding)
|
||||
refiners_mask_decoder.set_mask_embedding(mask_embedding)
|
||||
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
|
||||
# Perform (1) upscaling then (2) mask prediction in this order (= like in the official implementation) to make
|
||||
# `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default)
|
||||
matmul = refiners_mask_decoder.ensure_find(fl.Matmul)
|
||||
|
||||
def forward_swapped_order(self: Any, *args: Any) -> Any:
|
||||
y = self[1](*args) # (1)
|
||||
x = self[0](*args) # (2)
|
||||
return torch.matmul(input=x, other=y)
|
||||
|
||||
matmul.forward = types.MethodType(forward_swapped_order, matmul)
|
||||
|
||||
assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Converts a Segment Anything ViT model to a Refiners SAMViTH model")
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
dest="source_path",
|
||||
default="sam_vit_h_4b8939.pth",
|
||||
# required=True,
|
||||
help="Path to the Segment Anything model weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default="segment-anything-h.safetensors",
|
||||
help="Output path for converted model (as safetensors).",
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False")
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Prints additional information during conversion. Default: False",
|
||||
)
|
||||
args = parser.parse_args(namespace=Args())
|
||||
|
||||
sam_h = build_sam_vit_h() # type: ignore
|
||||
sam_h.load_state_dict(state_dict=load_tensors(args.source_path))
|
||||
|
||||
vit_state_dict = convert_vit(vit=sam_h.image_encoder)
|
||||
mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
|
||||
point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder)
|
||||
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
|
||||
|
||||
output_state_dict = {
|
||||
**{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()},
|
||||
**{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()},
|
||||
**{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()},
|
||||
**{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()},
|
||||
}
|
||||
if args.half:
|
||||
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}
|
||||
|
||||
save_to_safetensors(path=args.output_path, tensors=output_state_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,149 +0,0 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPVisionModelWithProjection # type: ignore
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
source_path: str
|
||||
subfolder: str
|
||||
output_path: str | None
|
||||
half: bool
|
||||
verbose: bool
|
||||
threshold: float
|
||||
|
||||
|
||||
class CLIPImageEncoderConfig(NamedTuple):
|
||||
architectures: list[str]
|
||||
num_channels: int
|
||||
hidden_size: int
|
||||
hidden_act: str
|
||||
image_size: int
|
||||
projection_dim: int
|
||||
patch_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
intermediate_size: int
|
||||
layer_norm_eps: float
|
||||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = CLIPVisionModelWithProjection.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
subfolder=args.subfolder,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||
config = cast(CLIPImageEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType]
|
||||
|
||||
assert (
|
||||
config.architectures[0] == "CLIPVisionModelWithProjection"
|
||||
), f"Unsupported architecture: {config.architectures[0]}"
|
||||
assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
|
||||
assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"
|
||||
|
||||
target = CLIPImageEncoder(
|
||||
image_size=config.image_size,
|
||||
embedding_dim=config.hidden_size,
|
||||
output_dim=config.projection_dim,
|
||||
patch_size=config.patch_size,
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
feedforward_dim=config.intermediate_size,
|
||||
layer_norm_eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
x = torch.randn(1, 3, config.image_size, config.image_size)
|
||||
|
||||
converter = ModelConverter(source_model=source, target_model=target, verbose=True)
|
||||
|
||||
# Custom conversion logic since the class embedding (fl.Parameter layer) is not supported out-of-the-box by the
|
||||
# converter
|
||||
mapping = converter.map_state_dicts((x,))
|
||||
assert mapping is not None
|
||||
|
||||
source_state_dict = source.state_dict()
|
||||
target_state_dict = target.state_dict()
|
||||
|
||||
# Remove the class embedding from state dict since it was not mapped by the model converter
|
||||
class_embedding = target.ensure_find(fl.Parameter)
|
||||
class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
|
||||
assert class_embedding_key is not None
|
||||
assert class_embedding_key in target_state_dict
|
||||
del target_state_dict[class_embedding_key]
|
||||
|
||||
converted_state_dict = converter._convert_state_dict( # type: ignore[reportPrivateUsage]
|
||||
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||
)
|
||||
target.load_state_dict(state_dict=converted_state_dict, strict=False)
|
||||
|
||||
# Ad hoc post-conversion steps
|
||||
embed = source.vision_model.embeddings.class_embedding
|
||||
class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight)) # type: ignore
|
||||
|
||||
assert converter.compare_models((x,), threshold=args.threshold)
|
||||
|
||||
return converter
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Converts a CLIPImageEncoder from the library transformers from the HuggingFace Hub to refiners."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
dest="source_path",
|
||||
default="stabilityai/stable-diffusion-2-1-unclip",
|
||||
help=(
|
||||
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
|
||||
" stabilityai/stable-diffusion-2-1-unclip"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subfolder",
|
||||
type=str,
|
||||
dest="subfolder",
|
||||
default="image_encoder",
|
||||
help="Subfolder in the source path where the model is located inside the Hub. Default: image_encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default=None,
|
||||
help=(
|
||||
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
|
||||
" source path."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", help="Convert to half precision.")
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Prints additional information during conversion. Default: False",
|
||||
)
|
||||
parser.add_argument("--threshold", type=float, default=1e-2, help="Threshold for model comparison. Default: 1e-2")
|
||||
args = parser.parse_args(namespace=Args())
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
||||
converter = setup_converter(args=args)
|
||||
# Do not use converter.save_to_safetensors since it is not in a valid state due to the ad hoc conversion
|
||||
state_dict = converter.target_model.state_dict()
|
||||
if args.half:
|
||||
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,150 +0,0 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, cast
|
||||
|
||||
from torch import nn
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection # type: ignore
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderG, CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
source_path: str
|
||||
subfolder: str
|
||||
output_path: str | None
|
||||
half: bool
|
||||
verbose: bool
|
||||
|
||||
|
||||
class CLIPTextEncoderConfig(NamedTuple):
|
||||
architectures: list[str]
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
hidden_act: str
|
||||
layer_norm_eps: float
|
||||
projection_dim: int
|
||||
|
||||
|
||||
def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
|
||||
source: nn.Module = cls.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
subfolder=args.subfolder,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||
config = cast(CLIPTextEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType]
|
||||
architecture: str = config.architectures[0]
|
||||
embedding_dim: int = config.hidden_size
|
||||
projection_dim: int = config.projection_dim
|
||||
use_quick_gelu = config.hidden_act == "quick_gelu"
|
||||
assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
|
||||
target = CLIPTextEncoder(
|
||||
embedding_dim=config.hidden_size,
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
feedforward_dim=config.intermediate_size,
|
||||
use_quick_gelu=use_quick_gelu,
|
||||
)
|
||||
if architecture == "CLIPTextModelWithProjection":
|
||||
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
|
||||
text = "What a nice cat you have there!"
|
||||
tokenizer = target.ensure_find(CLIPTokenizer)
|
||||
tokens = tokenizer(text)
|
||||
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||
if not converter.run(source_args=(tokens,), target_args=(text,)):
|
||||
raise RuntimeError("Model conversion failed")
|
||||
return converter
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Converts a CLIPTextEncoder from the library transformers from the HuggingFace Hub to refiners."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
dest="source_path",
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
help=(
|
||||
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
|
||||
" runwayml/stable-diffusion-v1-5"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subfolder",
|
||||
type=str,
|
||||
dest="subfolder",
|
||||
default="text_encoder",
|
||||
help=(
|
||||
"Subfolder in the source path where the model is located inside the Hub. Default: text_encoder (for"
|
||||
" CLIPTextModel)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subfolder2",
|
||||
type=str,
|
||||
dest="subfolder2",
|
||||
default=None,
|
||||
help="Additional subfolder for the 2nd text encoder (useful for SDXL). Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default=None,
|
||||
help=(
|
||||
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
|
||||
" source path."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", help="Convert to half precision.")
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Prints additional information during conversion. Default: False",
|
||||
)
|
||||
args = parser.parse_args(namespace=Args())
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
||||
converter = setup_converter(args=args)
|
||||
if args.subfolder2 is not None:
|
||||
# Assume this is the second text encoder of Stable Diffusion XL
|
||||
args.subfolder = args.subfolder2
|
||||
converter2 = setup_converter(args=args, with_projection=True)
|
||||
|
||||
text_encoder_l = CLIPTextEncoderL()
|
||||
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
|
||||
|
||||
projection = cast(CLIPTextEncoder, converter2.target_model)[-1]
|
||||
assert isinstance(projection, fl.Linear)
|
||||
text_encoder_g_with_projection = CLIPTextEncoderG()
|
||||
text_encoder_g_with_projection.append(module=projection)
|
||||
text_encoder_g_with_projection.load_state_dict(state_dict=converter2.get_state_dict())
|
||||
|
||||
projection = text_encoder_g_with_projection.pop(index=-1)
|
||||
assert isinstance(projection, fl.Linear)
|
||||
double_text_encoder = DoubleTextEncoder(
|
||||
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=projection
|
||||
)
|
||||
|
||||
state_dict = double_text_encoder.state_dict()
|
||||
if args.half:
|
||||
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||
else:
|
||||
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,945 +0,0 @@
|
|||
"""
|
||||
Download and convert weights for testing
|
||||
|
||||
To see what weights will be downloaded and converted, run:
|
||||
DRY_RUN=1 python scripts/prepare_test_weights.py
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gdown
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
# Set the base directory to the parent directory of the script
|
||||
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
test_weights_dir = os.path.join(project_dir, "tests", "weights")
|
||||
|
||||
previous_line = "\033[F"
|
||||
|
||||
download_count = 0
|
||||
bytes_count = 0
|
||||
|
||||
|
||||
def die(message: str) -> None:
|
||||
print(message, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def rel(path: str) -> str:
|
||||
return os.path.relpath(path, project_dir)
|
||||
|
||||
|
||||
def calc_hash(filepath: str) -> str:
|
||||
with open(filepath, "rb") as f:
|
||||
data = f.read()
|
||||
found = hashlib.blake2b(data, digest_size=int(32 / 8)).hexdigest()
|
||||
return found
|
||||
|
||||
|
||||
def check_hash(path: str, expected: str) -> str:
|
||||
found = calc_hash(path)
|
||||
if found != expected:
|
||||
die(f"❌ Invalid hash for {path} ({found} != {expected})")
|
||||
return found
|
||||
|
||||
|
||||
def download_file(
|
||||
url: str,
|
||||
dest_folder: str,
|
||||
dry_run: bool | None = None,
|
||||
skip_existing: bool = True,
|
||||
expected_hash: str | None = None,
|
||||
filename: str | None = None,
|
||||
):
|
||||
"""
|
||||
Downloads a file
|
||||
|
||||
Features:
|
||||
- shows a progress bar
|
||||
- skips existing files
|
||||
- uses a temporary file to prevent partial downloads
|
||||
- can do a dry run to check the url is valid
|
||||
- displays the downloaded file hash
|
||||
|
||||
"""
|
||||
global download_count, bytes_count
|
||||
filename = os.path.basename(urlparse(url).path) if filename is None else filename
|
||||
dest_filename = os.path.join(dest_folder, filename)
|
||||
temp_filename = dest_filename + ".part"
|
||||
dry_run = bool(os.environ.get("DRY_RUN") == "1") if dry_run is None else dry_run
|
||||
|
||||
is_downloaded = os.path.exists(dest_filename)
|
||||
if is_downloaded and skip_existing:
|
||||
skip_icon = "✖️ "
|
||||
else:
|
||||
skip_icon = "🔽"
|
||||
|
||||
if dry_run:
|
||||
response = requests.head(url, allow_redirects=True)
|
||||
readable_size = ""
|
||||
|
||||
if response.status_code == 200:
|
||||
content_length = response.headers.get("content-length")
|
||||
|
||||
if content_length:
|
||||
size_in_bytes = int(content_length)
|
||||
readable_size = human_readable_size(size_in_bytes)
|
||||
download_count += 1
|
||||
bytes_count += size_in_bytes
|
||||
print(f"✅{skip_icon} {response.status_code} READY {readable_size:<8} {url}")
|
||||
|
||||
else:
|
||||
print(f"❌{skip_icon} {response.status_code} ERROR {readable_size:<8} {url}")
|
||||
return
|
||||
|
||||
if skip_existing and is_downloaded:
|
||||
print(f"{skip_icon}️ Skipping previously downloaded {url}")
|
||||
if expected_hash is not None:
|
||||
check_hash(dest_filename, expected_hash)
|
||||
return
|
||||
|
||||
os.makedirs(dest_folder, exist_ok=True)
|
||||
|
||||
print(f"🔽 Downloading {url} => '{rel(dest_filename)}'", end="\n")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code != 200:
|
||||
print(response.content[:1000])
|
||||
die(f"Failed to download {url}. Status code: {response.status_code}")
|
||||
total = int(response.headers.get("content-length", 0))
|
||||
bar = tqdm(
|
||||
desc=filename,
|
||||
total=total,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
leave=False,
|
||||
)
|
||||
with open(temp_filename, "wb") as f, bar:
|
||||
for data in response.iter_content(chunk_size=1024 * 1000):
|
||||
size = f.write(data)
|
||||
bar.update(size)
|
||||
|
||||
os.rename(temp_filename, dest_filename)
|
||||
calculated_hash = calc_hash(dest_filename)
|
||||
|
||||
print(f"{previous_line}✅ Downloaded {calculated_hash} {url} => '{rel(dest_filename)}' ")
|
||||
if expected_hash is not None:
|
||||
check_hash(dest_filename, expected_hash)
|
||||
|
||||
|
||||
def download_files(urls: list[str], dest_folder: str):
|
||||
for url in urls:
|
||||
download_file(url, dest_folder)
|
||||
|
||||
|
||||
def human_readable_size(size: int | float, decimal_places: int = 2) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
|
||||
if size < 1024.0:
|
||||
break
|
||||
size /= 1024.0
|
||||
return f"{size:.{decimal_places}f}{unit}" # type: ignore
|
||||
|
||||
|
||||
def download_sd_text_encoder(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "text_encoder"):
|
||||
encoder_filename = "model.safetensors" if "inpainting" not in hf_repo_id else "model.fp16.safetensors"
|
||||
base_url = f"https://huggingface.co/{hf_repo_id}"
|
||||
download_files(
|
||||
urls=[
|
||||
f"{base_url}/raw/main/{subdir}/config.json",
|
||||
f"{base_url}/resolve/main/{subdir}/{encoder_filename}",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir),
|
||||
)
|
||||
|
||||
|
||||
def download_sd_tokenizer(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "tokenizer"):
|
||||
download_files(
|
||||
urls=[
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/merges.txt",
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/special_tokens_map.json",
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/tokenizer_config.json",
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/vocab.json",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir),
|
||||
)
|
||||
|
||||
|
||||
def download_sd_base(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"):
|
||||
is_inpainting = "inpainting" in hf_repo_id
|
||||
ext = "safetensors" if not is_inpainting else "bin"
|
||||
base_folder = os.path.join(test_weights_dir, hf_repo_id)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/model_index.json", base_folder)
|
||||
download_file(
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/scheduler/scheduler_config.json",
|
||||
os.path.join(base_folder, "scheduler"),
|
||||
)
|
||||
|
||||
for subdir in ["unet", "vae"]:
|
||||
subdir_folder = os.path.join(base_folder, subdir)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder)
|
||||
download_file(
|
||||
f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/diffusion_pytorch_model.{ext}", subdir_folder
|
||||
)
|
||||
# we only need the unet for the inpainting model
|
||||
if not is_inpainting:
|
||||
download_sd_text_encoder(hf_repo_id, "text_encoder")
|
||||
download_sd_tokenizer(hf_repo_id, "tokenizer")
|
||||
|
||||
|
||||
def download_sd15(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"):
|
||||
download_sd_base(hf_repo_id)
|
||||
base_folder = os.path.join(test_weights_dir, hf_repo_id)
|
||||
|
||||
subdir = "feature_extractor"
|
||||
download_file(
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/preprocessor_config.json",
|
||||
os.path.join(base_folder, subdir),
|
||||
)
|
||||
|
||||
if "inpainting" not in hf_repo_id:
|
||||
subdir = "safety_checker"
|
||||
subdir_folder = os.path.join(base_folder, subdir)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/model.safetensors", subdir_folder)
|
||||
|
||||
|
||||
def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
|
||||
download_sd_base(hf_repo_id)
|
||||
download_sd_text_encoder(hf_repo_id, "text_encoder_2")
|
||||
download_sd_tokenizer(hf_repo_id, "tokenizer_2")
|
||||
|
||||
|
||||
def download_vae_fp16_fix():
|
||||
download_files(
|
||||
urls=[
|
||||
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/raw/main/config.json",
|
||||
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, "madebyollin", "sdxl-vae-fp16-fix"),
|
||||
)
|
||||
|
||||
|
||||
def download_vae_ft_mse():
|
||||
download_files(
|
||||
urls=[
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse/raw/main/config.json",
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, "stabilityai", "sd-vae-ft-mse"),
|
||||
)
|
||||
|
||||
|
||||
def download_loras():
|
||||
dest_folder = os.path.join(test_weights_dir, "loras", "pokemon-lora")
|
||||
download_file(
|
||||
"https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin",
|
||||
dest_folder,
|
||||
expected_hash="89992ea6",
|
||||
)
|
||||
|
||||
dest_folder = os.path.join(test_weights_dir, "loras", "dpo-lora")
|
||||
download_file(
|
||||
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors",
|
||||
dest_folder,
|
||||
expected_hash="a51e9144",
|
||||
)
|
||||
|
||||
dest_folder = os.path.join(test_weights_dir, "loras", "sliders")
|
||||
download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder, expected_hash="908f07d3")
|
||||
download_file(
|
||||
"https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder, expected_hash="25652004"
|
||||
)
|
||||
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder, expected_hash="ee170e4d")
|
||||
|
||||
dest_folder = os.path.join(test_weights_dir, "loras")
|
||||
download_file(
|
||||
"https://civitai.com/api/download/models/140624",
|
||||
filename="Sci-fi_Environments_sdxl.safetensors",
|
||||
dest_folder=dest_folder,
|
||||
expected_hash="6a4afda8",
|
||||
)
|
||||
download_file(
|
||||
"https://civitai.com/api/download/models/135931",
|
||||
filename="pixel-art-xl-v1.1.safetensors",
|
||||
dest_folder=dest_folder,
|
||||
expected_hash="71aaa6ca",
|
||||
)
|
||||
|
||||
|
||||
def download_preprocessors():
|
||||
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
||||
download_file("https://huggingface.co/spaces/carolineec/informativedrawings/resolve/main/model2.pth", dest_folder)
|
||||
|
||||
|
||||
def download_controlnet():
|
||||
base_folder = os.path.join(test_weights_dir, "lllyasviel")
|
||||
controlnets = [
|
||||
"control_v11p_sd15_canny",
|
||||
"control_v11f1p_sd15_depth",
|
||||
"control_v11p_sd15_normalbae",
|
||||
"control_v11p_sd15_lineart",
|
||||
]
|
||||
for net in controlnets:
|
||||
net_folder = os.path.join(base_folder, net)
|
||||
urls = [
|
||||
f"https://huggingface.co/lllyasviel/{net}/raw/main/config.json",
|
||||
f"https://huggingface.co/lllyasviel/{net}/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
download_files(urls, net_folder)
|
||||
|
||||
tile_folder = os.path.join(base_folder, "control_v11f1e_sd15_tile")
|
||||
urls = [
|
||||
"https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/raw/main/config.json",
|
||||
"https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/main/diffusion_pytorch_model.bin",
|
||||
]
|
||||
download_files(urls, tile_folder)
|
||||
|
||||
mfidabel_folder = os.path.join(test_weights_dir, "mfidabel", "controlnet-segment-anything")
|
||||
urls = [
|
||||
"https://huggingface.co/mfidabel/controlnet-segment-anything/raw/main/config.json",
|
||||
"https://huggingface.co/mfidabel/controlnet-segment-anything/resolve/main/diffusion_pytorch_model.bin",
|
||||
]
|
||||
download_files(urls, mfidabel_folder)
|
||||
|
||||
|
||||
def download_control_lora_fooocus():
|
||||
base_folder = os.path.join(test_weights_dir, "lllyasviel", "misc")
|
||||
|
||||
download_file(
|
||||
url=f"https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors",
|
||||
dest_folder=base_folder,
|
||||
expected_hash="fec9e32b",
|
||||
)
|
||||
|
||||
download_file(
|
||||
url=f"https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors",
|
||||
dest_folder=base_folder,
|
||||
expected_hash="fc04b120",
|
||||
)
|
||||
|
||||
|
||||
def download_unclip():
|
||||
base_folder = os.path.join(test_weights_dir, "stabilityai", "stable-diffusion-2-1-unclip")
|
||||
download_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/model_index.json", base_folder
|
||||
)
|
||||
image_encoder_folder = os.path.join(base_folder, "image_encoder")
|
||||
urls = [
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/image_encoder/config.json",
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.safetensors",
|
||||
]
|
||||
download_files(urls, image_encoder_folder)
|
||||
|
||||
|
||||
def download_ip_adapter():
|
||||
base_folder = os.path.join(test_weights_dir, "h94", "IP-Adapter")
|
||||
models_folder = os.path.join(base_folder, "models")
|
||||
urls = [
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.bin",
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.bin",
|
||||
]
|
||||
download_files(urls, models_folder)
|
||||
|
||||
sdxl_models_folder = os.path.join(base_folder, "sdxl_models")
|
||||
urls = [
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.bin",
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
|
||||
]
|
||||
download_files(urls, sdxl_models_folder)
|
||||
|
||||
|
||||
def download_t5xl_fp16():
|
||||
base_folder = os.path.join(test_weights_dir, "QQGYLab", "T5XLFP16")
|
||||
urls = [
|
||||
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/config.json",
|
||||
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/model.safetensors",
|
||||
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/special_tokens_map.json",
|
||||
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/spiece.model",
|
||||
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer.json",
|
||||
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer_config.json",
|
||||
]
|
||||
download_files(urls, base_folder)
|
||||
|
||||
|
||||
def download_ella_adapter():
|
||||
download_t5xl_fp16()
|
||||
base_folder = os.path.join(test_weights_dir, "QQGYLab", "ELLA")
|
||||
download_file(
|
||||
"https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors",
|
||||
base_folder,
|
||||
expected_hash="5af7b200",
|
||||
)
|
||||
|
||||
|
||||
def download_t2i_adapter():
|
||||
base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
|
||||
urls = [
|
||||
"https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/raw/main/config.json",
|
||||
"https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/resolve/main/diffusion_pytorch_model.bin",
|
||||
]
|
||||
download_files(urls, base_folder)
|
||||
|
||||
canny_sdxl_folder = os.path.join(test_weights_dir, "TencentARC", "t2i-adapter-canny-sdxl-1.0")
|
||||
urls = [
|
||||
"https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/raw/main/config.json",
|
||||
"https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
download_files(urls, canny_sdxl_folder)
|
||||
|
||||
|
||||
def download_sam():
|
||||
weights_folder = os.path.join(test_weights_dir)
|
||||
download_file(
|
||||
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", weights_folder, expected_hash="06785e66"
|
||||
)
|
||||
|
||||
|
||||
def download_hq_sam():
|
||||
weights_folder = os.path.join(test_weights_dir)
|
||||
download_file(
|
||||
"https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", weights_folder, expected_hash="66da2472"
|
||||
)
|
||||
|
||||
|
||||
def download_dinov2():
|
||||
# For conversion
|
||||
weights_folder = os.path.join(test_weights_dir)
|
||||
urls = [
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth",
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
|
||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
|
||||
]
|
||||
download_files(urls, weights_folder)
|
||||
|
||||
|
||||
def download_lcm_base():
|
||||
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
|
||||
download_file(f"https://huggingface.co/latent-consistency/lcm-sdxl/raw/main/config.json", base_folder)
|
||||
download_file(
|
||||
f"https://huggingface.co/latent-consistency/lcm-sdxl/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
base_folder,
|
||||
)
|
||||
|
||||
|
||||
def download_lcm_lora():
|
||||
download_file(
|
||||
"https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors",
|
||||
dest_folder=test_weights_dir,
|
||||
filename="sdxl-lcm-lora.safetensors",
|
||||
expected_hash="6312a30a",
|
||||
)
|
||||
|
||||
|
||||
def download_sdxl_lightning_base():
|
||||
base_folder = os.path.join(test_weights_dir, "ByteDance/SDXL-Lightning")
|
||||
download_file(
|
||||
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_unet.safetensors",
|
||||
base_folder,
|
||||
expected_hash="1b76cca3",
|
||||
)
|
||||
download_file(
|
||||
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_1step_unet_x0.safetensors",
|
||||
base_folder,
|
||||
expected_hash="38e605bd",
|
||||
)
|
||||
|
||||
|
||||
def download_sdxl_lightning_lora():
|
||||
download_file(
|
||||
"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors",
|
||||
dest_folder=test_weights_dir,
|
||||
expected_hash="9783edac",
|
||||
)
|
||||
|
||||
|
||||
def download_ic_light():
|
||||
download_file(
|
||||
"https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors",
|
||||
dest_folder=test_weights_dir,
|
||||
expected_hash="bce70123",
|
||||
)
|
||||
|
||||
|
||||
def download_mvanet():
|
||||
fn = "Model_80.pth"
|
||||
dest_folder = os.path.join(test_weights_dir, "mvanet")
|
||||
dest_filename = os.path.join(dest_folder, fn)
|
||||
|
||||
if os.environ.get("DRY_RUN") == "1":
|
||||
return
|
||||
|
||||
if os.path.exists(dest_filename):
|
||||
print(f"✖️ ️ Skipping previously downloaded mvanet/{fn}")
|
||||
else:
|
||||
os.makedirs(dest_folder, exist_ok=True)
|
||||
print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n")
|
||||
gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True)
|
||||
print(f"{previous_line}✅ Downloaded mvanet/{fn} => '{rel(dest_filename)}' ")
|
||||
|
||||
check_hash(dest_filename, "b915d492")
|
||||
|
||||
|
||||
def download_box_segmenter():
|
||||
download_file(
|
||||
"https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors",
|
||||
dest_folder=test_weights_dir,
|
||||
filename="finegrain-box-segmenter-v0-1.safetensors",
|
||||
expected_hash="e0450e8c",
|
||||
)
|
||||
|
||||
|
||||
def printg(msg: str):
|
||||
"""print in green color"""
|
||||
print("\033[92m" + msg + "\033[0m")
|
||||
|
||||
|
||||
def run_conversion_script(
|
||||
script_filename: str,
|
||||
from_weights: str,
|
||||
to_weights: str,
|
||||
half: bool = False,
|
||||
expected_hash: str | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
skip_existing: bool = True,
|
||||
):
|
||||
if skip_existing and expected_hash and os.path.exists(to_weights):
|
||||
found_hash = check_hash(to_weights, expected_hash)
|
||||
if expected_hash == found_hash:
|
||||
printg(f"✖️ Skipping converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ")
|
||||
return
|
||||
|
||||
msg = f"Converting {from_weights} to {to_weights}"
|
||||
printg(msg)
|
||||
|
||||
args = ["python", f"scripts/conversion/{script_filename}", "--from", from_weights, "--to", to_weights]
|
||||
if half:
|
||||
args.append("--half")
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
|
||||
subprocess.run(args, check=True)
|
||||
if expected_hash is not None:
|
||||
found_hash = check_hash(to_weights, expected_hash)
|
||||
printg(f"✅ Converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ")
|
||||
else:
|
||||
printg(f"✅⚠️ Converted from {from_weights} to {to_weights} (no hash check performed)")
|
||||
|
||||
|
||||
def convert_sd15():
|
||||
run_conversion_script(
|
||||
script_filename="convert_transformers_clip_text_model.py",
|
||||
from_weights="tests/weights/runwayml/stable-diffusion-v1-5",
|
||||
to_weights="tests/weights/CLIPTextEncoderL.safetensors",
|
||||
half=True,
|
||||
expected_hash="6c9cbc59",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/runwayml/stable-diffusion-v1-5",
|
||||
"tests/weights/lda.safetensors",
|
||||
expected_hash="329e369c",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/runwayml/stable-diffusion-v1-5",
|
||||
"tests/weights/unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="f81ac65a",
|
||||
)
|
||||
os.makedirs("tests/weights/inpainting", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/runwayml/stable-diffusion-inpainting",
|
||||
"tests/weights/inpainting/unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="c07a8c61",
|
||||
)
|
||||
|
||||
|
||||
def convert_sdxl():
|
||||
run_conversion_script(
|
||||
"convert_transformers_clip_text_model.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/DoubleCLIPTextEncoder.safetensors",
|
||||
half=True,
|
||||
expected_hash="7f99c30b",
|
||||
additional_args=["--subfolder2", "text_encoder_2"],
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl-lda.safetensors",
|
||||
half=True,
|
||||
expected_hash="7464e9dc",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl-unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="2e5c4911",
|
||||
)
|
||||
|
||||
|
||||
def convert_vae_ft_mse():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/stabilityai/sd-vae-ft-mse",
|
||||
"tests/weights/lda_ft_mse.safetensors",
|
||||
half=True,
|
||||
expected_hash="4d0bae7e",
|
||||
)
|
||||
|
||||
|
||||
def convert_vae_fp16_fix():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/madebyollin/sdxl-vae-fp16-fix",
|
||||
"tests/weights/sdxl-lda-fp16-fix.safetensors",
|
||||
additional_args=["--subfolder", "''"],
|
||||
half=True,
|
||||
expected_hash="98c7e998",
|
||||
)
|
||||
|
||||
|
||||
def convert_preprocessors():
|
||||
subprocess.run(
|
||||
[
|
||||
"curl",
|
||||
"-L",
|
||||
"https://raw.githubusercontent.com/carolineec/informative-drawings/main/model.py",
|
||||
"-o",
|
||||
"src/model.py",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_informative_drawings.py",
|
||||
"tests/weights/carolineec/informativedrawings/model2.pth",
|
||||
"tests/weights/informative-drawings.safetensors",
|
||||
expected_hash="93dca207",
|
||||
)
|
||||
os.remove("src/model.py")
|
||||
|
||||
|
||||
def convert_controlnet():
|
||||
os.makedirs("tests/weights/controlnet", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11p_sd15_canny",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors",
|
||||
expected_hash="9a1a48cf",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11f1p_sd15_depth",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors",
|
||||
expected_hash="bbe7e5a6",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11p_sd15_normalbae",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors",
|
||||
expected_hash="9fa88ed5",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11p_sd15_lineart",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors",
|
||||
expected_hash="c29e8c03",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/mfidabel/controlnet-segment-anything",
|
||||
"tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors",
|
||||
expected_hash="d536eebb",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11f1e_sd15_tile",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11f1e_sd15_tile.safetensors",
|
||||
expected_hash="42463af8",
|
||||
)
|
||||
|
||||
|
||||
def convert_unclip():
|
||||
run_conversion_script(
|
||||
"convert_transformers_clip_image_model.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-2-1-unclip",
|
||||
"tests/weights/CLIPImageEncoderH.safetensors",
|
||||
half=True,
|
||||
expected_hash="4ddb44d2",
|
||||
)
|
||||
|
||||
|
||||
def convert_ip_adapter():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/models/ip-adapter_sd15.bin",
|
||||
"tests/weights/ip-adapter_sd15.safetensors",
|
||||
expected_hash="3fb0472e",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin",
|
||||
"tests/weights/ip-adapter_sdxl_vit-h.safetensors",
|
||||
half=True,
|
||||
expected_hash="860518fe",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin",
|
||||
"tests/weights/ip-adapter-plus_sd15.safetensors",
|
||||
half=True,
|
||||
expected_hash="aba8503b",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
|
||||
"tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors",
|
||||
half=True,
|
||||
expected_hash="545d5ce7",
|
||||
)
|
||||
|
||||
|
||||
def convert_ella_adapter():
|
||||
os.makedirs("tests/weights/ELLA-Adapter", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_ella_adapter.py",
|
||||
"tests/weights/QQGYLab/ELLA/ella-sd1.5-tsc-t5xl.safetensors",
|
||||
"tests/weights/ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors",
|
||||
half=True,
|
||||
expected_hash="b8244cb6",
|
||||
)
|
||||
|
||||
|
||||
def convert_t2i_adapter():
|
||||
os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_t2i_adapter.py",
|
||||
"tests/weights/TencentARC/t2iadapter_depth_sd15v2",
|
||||
"tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors",
|
||||
half=True,
|
||||
expected_hash="bb2b3115",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_t2i_adapter.py",
|
||||
"tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0",
|
||||
"tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors",
|
||||
half=True,
|
||||
expected_hash="f07249a6",
|
||||
)
|
||||
|
||||
|
||||
def convert_sam():
|
||||
run_conversion_script(
|
||||
"convert_segment_anything.py",
|
||||
"tests/weights/sam_vit_h_4b8939.pth",
|
||||
"tests/weights/segment-anything-h.safetensors",
|
||||
expected_hash="5ffb976f",
|
||||
)
|
||||
|
||||
|
||||
def convert_hq_sam():
|
||||
run_conversion_script(
|
||||
"convert_hq_segment_anything.py",
|
||||
"tests/weights/sam_hq_vit_h.pth",
|
||||
"tests/weights/refiners-sam-hq-vit-h.safetensors",
|
||||
expected_hash="b2f5e79f",
|
||||
)
|
||||
|
||||
|
||||
def convert_dinov2():
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vits14_pretrain.pth",
|
||||
"tests/weights/dinov2_vits14_pretrain.safetensors",
|
||||
expected_hash="af000ded",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitb14_pretrain.pth",
|
||||
"tests/weights/dinov2_vitb14_pretrain.safetensors",
|
||||
expected_hash="d6294087",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitl14_pretrain.pth",
|
||||
"tests/weights/dinov2_vitl14_pretrain.safetensors",
|
||||
expected_hash="ddd4819f",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitg14_pretrain.pth",
|
||||
"tests/weights/dinov2_vitg14_pretrain.safetensors",
|
||||
expected_hash="880c61f5",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
|
||||
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
|
||||
expected_hash="080247c7",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitb14_reg4_pretrain.pth",
|
||||
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
|
||||
expected_hash="5cd4d408",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitl14_reg4_pretrain.pth",
|
||||
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
|
||||
expected_hash="b1221702",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitg14_reg4_pretrain.pth",
|
||||
"tests/weights/dinov2_vitg14_reg4_pretrain.safetensors",
|
||||
expected_hash="639398eb",
|
||||
)
|
||||
|
||||
|
||||
def convert_control_lora_fooocus():
|
||||
run_conversion_script(
|
||||
"convert_fooocus_control_lora.py",
|
||||
"tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors",
|
||||
"tests/weights/control-loras/refiners_control-lora-canny-rank128.safetensors",
|
||||
expected_hash="4d505134",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_fooocus_control_lora.py",
|
||||
"tests/weights/lllyasviel/misc/fooocus_xl_cpds_128.safetensors",
|
||||
"tests/weights/control-loras/refiners_fooocus_xl_cpds_128.safetensors",
|
||||
expected_hash="d81aa461",
|
||||
)
|
||||
|
||||
|
||||
def convert_lcm_base():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/latent-consistency/lcm-sdxl",
|
||||
"tests/weights/sdxl-lcm-unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="e161b20c",
|
||||
)
|
||||
|
||||
|
||||
def convert_sdxl_lightning_base():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl_lightning_4step_unet.safetensors",
|
||||
additional_args=[
|
||||
"--override-weights",
|
||||
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors",
|
||||
],
|
||||
half=True,
|
||||
expected_hash="cfdc46da",
|
||||
)
|
||||
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl_lightning_1step_unet_x0.safetensors",
|
||||
additional_args=[
|
||||
"--override-weights",
|
||||
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_1step_unet_x0.safetensors",
|
||||
],
|
||||
half=True,
|
||||
expected_hash="21166a64",
|
||||
)
|
||||
|
||||
|
||||
def convert_ic_light():
|
||||
run_conversion_script(
|
||||
"convert_ic_light.py",
|
||||
"tests/weights/iclight_sd15_fc.safetensors",
|
||||
"tests/weights/iclight_sd15_fc-refiners.safetensors",
|
||||
half=False,
|
||||
expected_hash="be315c1f",
|
||||
)
|
||||
|
||||
|
||||
def convert_mvanet():
|
||||
run_conversion_script(
|
||||
"convert_mvanet.py",
|
||||
"tests/weights/mvanet/Model_80.pth",
|
||||
"tests/weights/mvanet/mvanet.safetensors",
|
||||
half=True,
|
||||
expected_hash="bf9ae4cb",
|
||||
)
|
||||
|
||||
|
||||
def download_all():
|
||||
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
||||
download_sd15("runwayml/stable-diffusion-v1-5")
|
||||
download_sd15("runwayml/stable-diffusion-inpainting")
|
||||
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
download_vae_ft_mse()
|
||||
download_vae_fp16_fix()
|
||||
download_loras()
|
||||
download_preprocessors()
|
||||
download_controlnet()
|
||||
download_unclip()
|
||||
download_ip_adapter()
|
||||
download_t2i_adapter()
|
||||
download_ella_adapter()
|
||||
download_sam()
|
||||
download_hq_sam()
|
||||
download_dinov2()
|
||||
download_control_lora_fooocus()
|
||||
download_lcm_base()
|
||||
download_lcm_lora()
|
||||
download_sdxl_lightning_base()
|
||||
download_sdxl_lightning_lora()
|
||||
download_ic_light()
|
||||
download_mvanet()
|
||||
download_box_segmenter()
|
||||
|
||||
|
||||
def convert_all():
|
||||
convert_sd15()
|
||||
convert_sdxl()
|
||||
convert_vae_ft_mse()
|
||||
convert_vae_fp16_fix()
|
||||
# Note: no convert loras: this is done at runtime by `SDLoraManager`
|
||||
convert_preprocessors()
|
||||
convert_controlnet()
|
||||
convert_unclip()
|
||||
convert_ip_adapter()
|
||||
convert_t2i_adapter()
|
||||
convert_ella_adapter()
|
||||
convert_sam()
|
||||
convert_hq_sam()
|
||||
convert_dinov2()
|
||||
convert_control_lora_fooocus()
|
||||
convert_lcm_base()
|
||||
convert_sdxl_lightning_base()
|
||||
convert_ic_light()
|
||||
convert_mvanet()
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
download_all()
|
||||
print(f"{download_count} files ({human_readable_size(bytes_count)})\n")
|
||||
if not bool(os.environ.get("DRY_RUN") == "1"):
|
||||
printg("Converting weights to refiners format\n")
|
||||
convert_all()
|
||||
except KeyboardInterrupt:
|
||||
print("Stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
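For context, a minimal sketch of how this deleted script was presumably invoked, based on the `DRY_RUN` check in `main` above; the script path is an assumption, not something stated in this diff, and the download helpers defined earlier in the file may also consult `DRY_RUN`:

```py
import os
import subprocess

# Hypothetical invocation of the deleted script (path assumed).
env = dict(os.environ, DRY_RUN="1")  # "1" downloads the weights but skips convert_all()
subprocess.run(["python", "scripts/prepare_test_weights.py"], check=True, env=env)
```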
@ -1,654 +0,0 @@
from collections import defaultdict
from enum import Enum, auto
from pathlib import Path
from typing import Any, DefaultDict, TypedDict

import torch
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle

from refiners.fluxion.utils import no_grad, norm, save_to_safetensors

TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
    nn.GroupNorm,
    nn.Embedding,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
]


ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]]


class ModuleArgsDict(TypedDict):
    """Represents positional and keyword arguments passed to a module.

    - `positional`: A tuple of positional arguments.
    - `keyword`: A dictionary of keyword arguments.
    """

    positional: tuple[Any, ...]
    keyword: dict[str, Any]
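For orientation (not part of the original file), a hedged sketch of the three shapes a `ModuleArgs` value can take, reusing the imports and `ModuleArgsDict` defined above; `ModuleArgsDict` is the fully explicit form, and the keyword name is illustrative since it depends on the traced module's `forward` signature:

```py
x = torch.randn(1, 3, 512, 512)  # placeholder input tensor

args_as_tuple = (x,)                                                 # positional arguments only
args_as_kwargs = {"x": x}                                            # keyword arguments only (name illustrative)
args_as_dict: ModuleArgsDict = {"positional": (x,), "keyword": {}}   # both, spelled out
```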
class ConversionStage(Enum):
    """Represents the current stage of the conversion process.

    Attributes:
        INIT: The conversion process has not started.
        BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers.
        SHAPE_AND_LAYERS_MATCH: The shapes of the basic layers in both models agree.
        MODELS_OUTPUT_AGREE: The outputs of the source and target models agree.
    """

    INIT = auto()
    BASIC_LAYERS_MATCH = auto()
    SHAPE_AND_LAYERS_MATCH = auto()
    MODELS_OUTPUT_AGREE = auto()


class ModelConverter:
    """Converts a model's state_dict to match another model's state_dict.

    Note: The conversion process consists of four stages
        1. Verify that the source and target models have the same number of basic layers.
        2. Find matching shapes and layers between the source and target models.
        3. Convert the source model's state_dict to match the target model's state_dict.
        4. Compare the outputs of the source and target models.

    The conversion process can be run multiple times, and will resume from the last stage.

    Example:
        ```py
        source = ...
        target = ...

        converter = ModelConverter(
            source_model=source,
            target_model=target,
            threshold=0.1,
            verbose=False
        )

        is_converted = converter.run(source_args=args)
        if is_converted:
            converter.save_to_safetensors(path="converted_model.safetensors")
        ```
    """

    ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
    stage: ConversionStage = ConversionStage.INIT
    _stored_mapping: dict[str, str] | None = None

    def __init__(
        self,
        source_model: nn.Module,
        target_model: nn.Module,
        source_keys_to_skip: list[str] | None = None,
        target_keys_to_skip: list[str] | None = None,
        custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
        threshold: float = 1e-5,
        skip_output_check: bool = False,
        skip_init_check: bool = False,
        verbose: bool = True,
    ) -> None:
        """Initializes the ModelConverter.

        Args:
            source_model: The model to convert from.
            target_model: The model to convert to.
            source_keys_to_skip: A list of keys to skip when tracing the source model.
            target_keys_to_skip: A list of keys to skip when tracing the target model.
            custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.
            threshold: The threshold for comparing outputs between the source and target models.
            skip_output_check: Whether to skip comparing the outputs of the source and target models.
            skip_init_check: Whether to skip checking that the source and target models have the same number of
                basic layers.
            verbose: Whether to print messages during the conversion process.
        """
        self.source_model = source_model
        self.target_model = target_model
        self.source_keys_to_skip = source_keys_to_skip or []
        self.target_keys_to_skip = target_keys_to_skip or []
        self.custom_layer_mapping = custom_layer_mapping or {}
        self.threshold = threshold
        self.skip_output_check = skip_output_check
        self.skip_init_check = skip_init_check
        self.verbose = verbose
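A hedged sketch of the less obvious constructor options; `source`, `target`, `SourceAttention`, `TargetAttention` and the skipped key are hypothetical placeholders, not names from this file:

```py
converter = ModelConverter(
    source_model=source,                                      # placeholder models
    target_model=target,
    custom_layer_mapping={SourceAttention: TargetAttention},  # treat these two layer types as equivalent
    source_keys_to_skip=["text_projection"],                  # do not trace this sub-module (illustrative key)
    skip_output_check=True,                                   # stop once the state_dict has been applied
)
```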
    def __repr__(self) -> str:
        return (
            f"ModelConverter(source_model={self.source_model.__class__.__name__},"
            f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
        )

    def __bool__(self) -> bool:
        return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3

    def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
        """Run the conversion process.

        Args:
            source_args: The arguments to pass to the source model; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If
                `target_args` is not provided, these arguments will also be passed to the target model.
            target_args: The arguments to pass to the target model; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.

        Returns:
            True if the conversion process is done and the models agree.
        """
        if target_args is None:
            target_args = source_args

        match self.stage:
            case ConversionStage.MODELS_OUTPUT_AGREE:
                self._increment_stage()
                return True

            case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
                source_args=source_args, target_args=target_args
            ):
                self._increment_stage()
                return True

            case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
                source_args=source_args, target_args=target_args
            ):
                self._increment_stage()
                return self.run(source_args=source_args, target_args=target_args)

            case ConversionStage.INIT if self._run_init_stage():
                self._increment_stage()
                return self.run(source_args=source_args, target_args=target_args)

            case _:
                self._log(message=f"Conversion failed at stage {self.stage.value}")
                return False

    def _increment_stage(self) -> None:
        """Increment the stage of the conversion process."""
        match self.stage:
            case ConversionStage.INIT:
                self.stage = ConversionStage.BASIC_LAYERS_MATCH
                self._log(
                    message=(
                        "Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
                        " layers..."
                    )
                )
            case ConversionStage.BASIC_LAYERS_MATCH:
                self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
                self._log(
                    message=(
                        "Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing"
                        " models..."
                    )
                )

            case ConversionStage.SHAPE_AND_LAYERS_MATCH:
                if self.skip_output_check:
                    self._log(
                        message=(
                            "Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
                            " `skip_output_check` to `False`"
                        )
                    )
                else:
                    self.stage = ConversionStage.MODELS_OUTPUT_AGREE
                    self._log(
                        message=(
                            "Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
                            " converted model using `save_to_safetensors`"
                        )
                    )
            case ConversionStage.MODELS_OUTPUT_AGREE:
                self._log(
                    message=(
                        "Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
                        " the converted model using `save_to_safetensors`"
                    )
                )
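As a small illustration of the resumable staging above (all names are placeholders): calling `run` again does not restart from `INIT` but from the stage reached by the previous call.

```py
x = torch.randn(1, 3, 512, 512)                  # placeholder input
converter = ModelConverter(source_model=source, target_model=target)
converter.run(source_args=(x,))                  # may stop early, e.g. if outputs diverge
converter.run(source_args=(x,))                  # resumes from converter.stage instead of starting over
```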
    def get_state_dict(self) -> dict[str, Tensor]:
        """Get the converted state_dict."""
        if not self:
            raise ValueError("The conversion process is not done yet. Run `converter.run(...)` first.")
        return self.target_model.state_dict()

    def get_mapping(self) -> dict[str, str]:
        """Get the mapping between the source and target models' state_dicts."""
        if not self:
            raise ValueError("The conversion process is not done yet. Run `converter.run(...)` first.")
        assert self._stored_mapping is not None, "Mapping is not stored"
        return self._stored_mapping

    def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
        """Save the converted model to a SafeTensors file.

        Warning:
            This method can only be called after the conversion process is done.

        Args:
            path: The path to save the converted model to.
            metadata: Metadata to save with the converted model.
            half: Whether to save the converted model as half precision.

        Raises:
            ValueError: If the conversion process is not done yet. Run `converter.run(...)` first.
        """
        if not self:
            raise ValueError("The conversion process is not done yet. Run `converter.run(...)` first.")
        state_dict = self.get_state_dict()
        if half:
            state_dict = {key: value.half() for key, value in state_dict.items()}
        save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)

    def map_state_dicts(
        self,
        source_args: ModuleArgs,
        target_args: ModuleArgs | None = None,
    ) -> dict[str, str] | None:
        """Find a mapping between the source and target models' state_dicts.

        Args:
            source_args: The arguments to pass to the source model; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If
                `target_args` is not provided, these arguments will also be passed to the target model.
            target_args: The arguments to pass to the target model; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.

        Returns:
            A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
        """
        if target_args is None:
            target_args = source_args

        source_order = self._trace_module_execution_order(
            module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
        )
        target_order = self._trace_module_execution_order(
            module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
        )

        if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
            return None

        mapping: dict[str, str] = {}
        for source_type_shape in source_order:
            source_keys = source_order[source_type_shape]
            target_type_shape = source_type_shape
            if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
                for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
                    if source_custom_type == source_type_shape[0]:
                        target_type_shape = (target_custom_type, source_type_shape[1])
                        break

            target_keys = target_order[target_type_shape]
            mapping.update(zip(target_keys, source_keys))

        return mapping

    def compare_models(
        self,
        source_args: ModuleArgs,
        target_args: ModuleArgs | None = None,
        threshold: float = 1e-5,
    ) -> bool:
        """Compare the outputs of the source and target models.

        Args:
            source_args: The arguments to pass to the source model; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If
                `target_args` is not provided, these arguments will also be passed to the target model.
            target_args: The arguments to pass to the target model; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
            threshold: The threshold for comparing outputs between the source and target models.

        Returns:
            True if the outputs of the source and target models agree.
        """
        if target_args is None:
            target_args = source_args

        source_outputs = self._collect_layers_outputs(
            module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
        )
        target_outputs = self._collect_layers_outputs(
            module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
        )

        diff, prev_source_key, prev_target_key = None, None, None
        for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
            diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
            if diff > threshold:
                self._log(
                    f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
                    f" {target_key}, difference in norm: {diff}"
                )
                return False
            prev_source_key, prev_target_key = source_key, target_key

        self._log(message=f"Models agree. Difference in norm: {diff}")

        return True

    def _run_init_stage(self) -> bool:
        """Run the init stage of the conversion process."""
        if self.skip_init_check:
            self._log(
                message=(
                    "Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to"
                    " `False`"
                )
            )
            return True

        is_count_correct = self._verify_basic_layers_count()
        is_not_missing_layers = self._verify_missing_basic_layers()

        return is_count_correct and is_not_missing_layers

    def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
        """Run the basic layers match stage of the conversion process."""
        mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
        self._stored_mapping = mapping
        if mapping is None:
            self._log(message="Models do not have matching shapes.")
            return False

        source_state_dict = self.source_model.state_dict()
        target_state_dict = self.target_model.state_dict()
        converted_state_dict = self._convert_state_dict(
            source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
        )
        self.target_model.load_state_dict(state_dict=converted_state_dict)

        return True

    def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
        """Run the shape and layers match stage of the conversion process."""
        if self.skip_output_check:
            self._log(
                message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`"
            )
            return True

        try:
            if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
                self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
                return True
            else:
                self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
                return False
        except Exception as e:
            self._log(message=f"An error occurred while comparing the models: {e}")
            return False

    def _log(self, message: str) -> None:
        """Print a message if `verbose` is `True`."""
        if self.verbose:
            print(message)

    def _debug_print_shapes(
        self,
        shape: ModelTypeShape,
        source_keys: list[str],
        target_keys: list[str],
    ) -> None:
        """Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
        self._log(message=f"{shape}")
        max_len = max(len(source_keys), len(target_keys))
        for i in range(max_len):
            source_key = source_keys[i] if i < len(source_keys) else "---"
            target_key = target_keys[i] if i < len(target_keys) else "---"
            self._log(f"\t{source_key}\t{target_key}")

    @staticmethod
    def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Unpack the positional and keyword arguments passed to a module."""
        match module_args:
            case tuple(positional_args):
                keyword_args: dict[str, Any] = {}
            case {"positional": positional_args, "keyword": keyword_args}:
                pass
            case _:
                positional_args = ()
                keyword_args = dict(**module_args)

        return positional_args, keyword_args

    def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
        """Check if a module type is a subclass of a torch basic layer."""
        return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)

    def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
        """Infer the type of a basic layer."""
        layer_types = (
            set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
        )
        for layer_type in layer_types:
            if isinstance(module, layer_type):
                return layer_type

        return None

    def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
        """Get the signature of a module."""
        layer_type = self._infer_basic_layer_type(module=module)
        assert layer_type is not None, f"Module {module} is not a basic layer"
        param_shapes = [p.shape for p in module.parameters()]
        return (layer_type, tuple(param_shapes))
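To make the signature format concrete, a hedged example for a plain linear layer (`converter` is a placeholder instance): the signature is the inferred basic layer type plus the shapes of its parameters, in `parameters()` order.

```py
linear = nn.Linear(in_features=5, out_features=10)
signature = converter.get_module_signature(module=linear)
# signature == (nn.Linear, (torch.Size([10, 5]), torch.Size([10])))  # weight shape, then bias shape
```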
    def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
        """Count the number of basic layers in a module."""
        count: DefaultDict[type[nn.Module], int] = defaultdict(int)
        for submodule in module.modules():
            layer_type = self._infer_basic_layer_type(module=submodule)
            if layer_type is not None:
                count[layer_type] += 1

        return count

    def _verify_basic_layers_count(self) -> bool:
        """Verify that the source and target models have the same number of basic layers."""
        source_layers = self._count_basic_layers(module=self.source_model)
        target_layers = self._count_basic_layers(module=self.target_model)

        reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}

        diff: dict[type[nn.Module], tuple[int, int]] = {}
        for layer_type, source_count in source_layers.items():
            target_type = self.custom_layer_mapping.get(layer_type, layer_type)
            target_count = target_layers.get(target_type, 0)

            if source_count != target_count:
                diff[layer_type] = (source_count, target_count)

        for layer_type, target_count in target_layers.items():
            source_type = reverse_mapping.get(layer_type, layer_type)
            source_count = source_layers.get(source_type, 0)

            if source_count != target_count:
                diff[layer_type] = (source_count, target_count)

        if diff:
            message = "Models do not have the same number of basic layers:\n"
            for layer_type, counts in diff.items():
                message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
            self._log(message=message.strip())
            return False

        return True

    def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
        """Check if a module is a leaf module with weights."""
        return next(module.parameters(), None) is not None and next(module.children(), None) is None

    def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
        """Check if a module has weighted leaf modules that are not basic layers."""
        return [
            type(submodule)
            for submodule in module.modules()
            if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
        ]

    def _verify_missing_basic_layers(self) -> bool:
        """Verify that the source and target models do not have missing basic layers."""
        missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
        missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)

        if missing_source_layers or missing_target_layers:
            self._log(
                message=(
                    "Models might have missing basic layers. If you want to skip this check, set"
                    f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
                )
            )
            return False

        return True

    @no_grad()
    def _trace_module_execution_order(
        self,
        module: nn.Module,
        args: ModuleArgs,
        keys_to_skip: list[str],
    ) -> dict[ModelTypeShape, list[str]]:
        """Execute a forward pass and store the order of execution of specific sub-modules.

        Args:
            module: The module to trace.
            args: The arguments to pass to the module; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
            keys_to_skip: A list of keys to skip when tracing the module.

        Returns:
            A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`.
        """
        submodule_to_key: dict[nn.Module, str] = {}
        execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)

        def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
            layer_signature = self.get_module_signature(module=layer)
            execution_order[layer_signature].append(submodule_to_key[layer])

        hooks: list[RemovableHandle] = []
        named_modules: list[tuple[str, nn.Module]] = module.named_modules()  # type: ignore
        for name, submodule in named_modules:
            if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
                submodule_to_key[submodule] = name  # type: ignore
                hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
                hooks.append(hook)

        positional_args, keyword_args = self._unpack_module_args(module_args=args)
        module(*positional_args, **keyword_args)

        for hook in hooks:
            hook.remove()

        return dict(execution_order)

    def _assert_shapes_aligned(
        self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
    ) -> bool:
        """Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
        model_type_shapes = set(source_order.keys()) | set(target_order.keys())

        default_type_shapes = [
            type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
        ]

        shape_mismatched = False

        for model_type_shape in default_type_shapes:
            source_keys = source_order.get(model_type_shape, [])
            target_keys = target_order.get(model_type_shape, [])

            if len(source_keys) != len(target_keys):
                shape_mismatched = True
                self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)

        for source_custom_type in self.custom_layer_mapping.keys():
            # iterate over all type_shapes that have the same type as source_custom_type
            for source_type_shape in [
                type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
            ]:
                source_keys = source_order.get(source_type_shape, [])
                target_custom_type = self.custom_layer_mapping[source_custom_type]
                target_type_shape = (target_custom_type, source_type_shape[1])
                target_keys = target_order.get(target_type_shape, [])

                if len(source_keys) != len(target_keys):
                    shape_mismatched = True
                    self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)

        return not shape_mismatched

    @staticmethod
    def _convert_state_dict(
        source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
    ) -> dict[str, Tensor]:
        """Convert the source model's state_dict to match the target model's state_dict."""
        converted_state_dict: dict[str, Tensor] = {}
        for target_key in target_state_dict:
            target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
            source_prefix = state_dict_mapping[target_prefix]
            source_key = ".".join([source_prefix, suffix])
            converted_state_dict[target_key] = source_state_dict[source_key]

        return converted_state_dict
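A small worked example of the remapping performed by `_convert_state_dict` (keys are hypothetical): the mapping relates module prefixes, and the parameter suffix is carried over unchanged.

```py
state_dict_mapping = {"encoder.proj": "down_blocks.0.conv"}  # target prefix -> source prefix (illustrative)
# target key "encoder.proj.weight" -> source key "down_blocks.0.conv.weight"
# target key "encoder.proj.bias"   -> source key "down_blocks.0.conv.bias"
```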
    @no_grad()
    def _collect_layers_outputs(
        self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
    ) -> list[tuple[str, Tensor]]:
        """Execute a forward pass and store the output of specific sub-modules.

        Args:
            module: The module to trace.
            args: The arguments to pass to the module; either a tuple of positional arguments,
                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
            keys_to_skip: A list of keys to skip when tracing the module.

        Returns:
            A list of tuples containing the key of each sub-module and its output.

        Note:
            The output of each sub-module is cloned to avoid memory leaks.
        """
        submodule_to_key: dict[nn.Module, str] = {}
        execution_order: list[tuple[str, Tensor]] = []

        def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None:
            execution_order.append((submodule_to_key[layer], output.clone()))

        hooks: list[RemovableHandle] = []
        named_modules: list[tuple[str, nn.Module]] = module.named_modules()  # type: ignore
        for name, submodule in named_modules:
            if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
                submodule_to_key[submodule] = name  # type: ignore
                hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
                hooks.append(hook)

        positional_args, keyword_args = self._unpack_module_args(module_args=args)
        module(*positional_args, **keyword_args)

        for hook in hooks:
            hook.remove()

        return execution_order