diff --git a/docs/reference/fluxion/model_converter.md b/docs/reference/fluxion/model_converter.md
deleted file mode 100644
index 39f4988..0000000
--- a/docs/reference/fluxion/model_converter.md
+++ /dev/null
@@ -1 +0,0 @@
-::: refiners.fluxion.model_converter
diff --git a/scripts/conversion/convert_diffusers_autoencoder_kl.py b/scripts/conversion/convert_diffusers_autoencoder_kl.py
deleted file mode 100644
index bcb68b4..0000000
--- a/scripts/conversion/convert_diffusers_autoencoder_kl.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import argparse
-from pathlib import Path
-
-import torch
-from diffusers import AutoencoderKL  # type: ignore
-from torch import nn
-
-from refiners.fluxion.model_converter import ModelConverter
-from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
-
-
-class Args(argparse.Namespace):
-    source_path: str
-    subfolder: str
-    output_path: str | None
-    use_half: bool
-    verbose: bool
-
-
-def setup_converter(args: Args) -> ModelConverter:
-    target = LatentDiffusionAutoencoder()
-    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
-    source: nn.Module = AutoencoderKL.from_pretrained(  # type: ignore
-        pretrained_model_name_or_path=args.source_path,
-        subfolder=args.subfolder,
-        low_cpu_mem_usage=False,
-    )  # type: ignore
-    x = torch.randn(1, 3, 512, 512)
-    converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
-    if not converter.run(source_args=(x,)):
-        raise RuntimeError("Model conversion failed")
-    return converter
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Convert a pretrained diffusers AutoencoderKL model to a refiners Latent Diffusion Autoencoder"
-    )
-    parser.add_argument(
-        "--from",
-        type=str,
-        dest="source_path",
-        default="runwayml/stable-diffusion-v1-5",
-        help="Path to the source pretrained model (default: 'runwayml/stable-diffusion-v1-5').",
-    )
-    parser.add_argument(
-        "--subfolder",
-        type=str,
-        dest="subfolder",
-        default="vae",
-        help="Subfolder in the source path where the model is located inside the Hub (default: 'vae').",
-    )
-    parser.add_argument(
-        "--to",
-        type=str,
-        dest="output_path",
-        default=None,
-        help=(
-            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
-            " be the source path with the extension changed to .safetensors."
-        ),
-    )
-    parser.add_argument(
-        "--half",
-        action="store_true",
-        dest="use_half",
-        default=False,
-        help="Use this flag to save the output file as half precision (default: full precision).",
-    )
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        dest="verbose",
-        default=False,
-        help="Use this flag to print verbose output during conversion.",
-    )
-    args = parser.parse_args(namespace=Args())
-    if args.output_path is None:
-        args.output_path = f"{Path(args.source_path).stem}-autoencoder.safetensors"
-    assert args.output_path is not None
-    converter = setup_converter(args=args)
-    converter.save_to_safetensors(path=args.output_path, half=args.use_half)
diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py
deleted file mode 100644
index d707298..0000000
--- a/scripts/conversion/convert_diffusers_controlnet.py
+++ /dev/null
@@ -1,233 +0,0 @@
-# pyright: reportPrivateUsage=false
-import argparse
-from pathlib import Path
-
-import torch
-from diffusers import ControlNetModel  # type: ignore
-from torch import nn
-
-from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import no_grad, save_to_safetensors
-from refiners.foundationals.latent_diffusion import (
-    DPMSolver,
-    SD1ControlnetAdapter,
-    SD1UNet,
-)
-
-
-class Args(argparse.Namespace):
-    source_path: str
-    output_path: str | None
-
-
-@no_grad()
-def convert(args: Args) -> dict[str, torch.Tensor]:
-    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
-    controlnet_src: nn.Module = ControlNetModel.from_pretrained(  # type: ignore
-        pretrained_model_name_or_path=args.source_path,
-        low_cpu_mem_usage=False,
-    )
-    unet = SD1UNet(in_channels=4)
-    adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
-    controlnet = adapter.controlnet
-
-    condition = torch.randn(1, 3, 512, 512)
-    adapter.set_controlnet_condition(condition=condition)
-
-    clip_text_embedding = torch.rand(1, 77, 768)
-    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
-
-    solver = DPMSolver(num_inference_steps=10)
-    timestep = solver.timesteps[0].unsqueeze(dim=0)
-    unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
-
-    x = torch.randn(1, 4, 64, 64)
-
-    # We need the hack below because our implementation is not strictly equivalent
-    # to diffusers in terms of execution order, since we compute the residuals inline
-    # instead of in a separate step.
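# (Editor's note: a minimal sketch of the fix below, using made-up keys. The
# converter buckets traced modules by (type, shapes) and pairs the source and
# target buckets positionally, so the traced source list must first be reordered
# to match the target's execution order before zipping:)
#
#   source_traced = ["attn_0.proj_in", "attn_0.proj_out", "zero_conv_0"]  # diffusers order
#   target_traced = ["Passthrough_0.Conv2d", "Attn_0.Conv2d_1", "Attn_0.Conv2d_2"]  # refiners order
#   fixed_source = ["zero_conv_0", "attn_0.proj_in", "attn_0.proj_out"]
#   mapping = dict(zip(target_traced, fixed_source))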
- - converter = ModelConverter( - source_model=controlnet_src, target_model=controlnet, skip_output_check=True, verbose=False - ) - - source_order = converter._trace_module_execution_order( - module=controlnet_src, args=(x, timestep, clip_text_embedding, condition), keys_to_skip=[] - ) - target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[]) - - broken_k = (nn.Conv2d, (torch.Size([320, 320, 1, 1]), torch.Size([320]))) - - expected_source_order = [ - "down_blocks.0.attentions.0.proj_in", - "down_blocks.0.attentions.0.proj_out", - "down_blocks.0.attentions.1.proj_in", - "down_blocks.0.attentions.1.proj_out", - "controlnet_down_blocks.0", - "controlnet_down_blocks.1", - "controlnet_down_blocks.2", - "controlnet_down_blocks.3", - ] - - expected_target_order = [ - "DownBlocks.Chain_1.Passthrough.Conv2d", - "DownBlocks.Chain_2.CLIPLCrossAttention.Chain_1.Conv2d", - "DownBlocks.Chain_2.CLIPLCrossAttention.Chain_3.Conv2d", - "DownBlocks.Chain_2.Passthrough.Conv2d", - "DownBlocks.Chain_3.CLIPLCrossAttention.Chain_1.Conv2d", - "DownBlocks.Chain_3.CLIPLCrossAttention.Chain_3.Conv2d", - "DownBlocks.Chain_3.Passthrough.Conv2d", - "DownBlocks.Chain_4.Passthrough.Conv2d", - ] - - fixed_source_order = [ - "controlnet_down_blocks.0", - "down_blocks.0.attentions.0.proj_in", - "down_blocks.0.attentions.0.proj_out", - "controlnet_down_blocks.1", - "down_blocks.0.attentions.1.proj_in", - "down_blocks.0.attentions.1.proj_out", - "controlnet_down_blocks.2", - "controlnet_down_blocks.3", - ] - - assert source_order[broken_k] == expected_source_order - assert target_order[broken_k] == expected_target_order - source_order[broken_k] = fixed_source_order - - broken_k = (nn.Conv2d, (torch.Size([640, 640, 1, 1]), torch.Size([640]))) - - expected_source_order = [ - "down_blocks.1.attentions.0.proj_in", - "down_blocks.1.attentions.0.proj_out", - "down_blocks.1.attentions.1.proj_in", - "down_blocks.1.attentions.1.proj_out", - "controlnet_down_blocks.4", - "controlnet_down_blocks.5", - "controlnet_down_blocks.6", - ] - - expected_target_order = [ - "DownBlocks.Chain_5.CLIPLCrossAttention.Chain_1.Conv2d", - "DownBlocks.Chain_5.CLIPLCrossAttention.Chain_3.Conv2d", - "DownBlocks.Chain_5.Passthrough.Conv2d", - "DownBlocks.Chain_6.CLIPLCrossAttention.Chain_1.Conv2d", - "DownBlocks.Chain_6.CLIPLCrossAttention.Chain_3.Conv2d", - "DownBlocks.Chain_6.Passthrough.Conv2d", - "DownBlocks.Chain_7.Passthrough.Conv2d", - ] - - fixed_source_order = [ - "down_blocks.1.attentions.0.proj_in", - "down_blocks.1.attentions.0.proj_out", - "controlnet_down_blocks.4", - "down_blocks.1.attentions.1.proj_in", - "down_blocks.1.attentions.1.proj_out", - "controlnet_down_blocks.5", - "controlnet_down_blocks.6", - ] - - assert source_order[broken_k] == expected_source_order - assert target_order[broken_k] == expected_target_order - source_order[broken_k] = fixed_source_order - - broken_k = (nn.Conv2d, (torch.Size([1280, 1280, 1, 1]), torch.Size([1280]))) - - expected_source_order = [ - "down_blocks.2.attentions.0.proj_in", - "down_blocks.2.attentions.0.proj_out", - "down_blocks.2.attentions.1.proj_in", - "down_blocks.2.attentions.1.proj_out", - "mid_block.attentions.0.proj_in", - "mid_block.attentions.0.proj_out", - "controlnet_down_blocks.7", - "controlnet_down_blocks.8", - "controlnet_down_blocks.9", - "controlnet_down_blocks.10", - "controlnet_down_blocks.11", - "controlnet_mid_block", - ] - - expected_target_order = [ - "DownBlocks.Chain_8.CLIPLCrossAttention.Chain_1.Conv2d", - 
"DownBlocks.Chain_8.CLIPLCrossAttention.Chain_3.Conv2d", - "DownBlocks.Chain_8.Passthrough.Conv2d", - "DownBlocks.Chain_9.CLIPLCrossAttention.Chain_1.Conv2d", - "DownBlocks.Chain_9.CLIPLCrossAttention.Chain_3.Conv2d", - "DownBlocks.Chain_9.Passthrough.Conv2d", - "DownBlocks.Chain_10.Passthrough.Conv2d", - "DownBlocks.Chain_11.Passthrough.Conv2d", - "DownBlocks.Chain_12.Passthrough.Conv2d", - "MiddleBlock.CLIPLCrossAttention.Chain_1.Conv2d", - "MiddleBlock.CLIPLCrossAttention.Chain_3.Conv2d", - "MiddleBlock.Passthrough.Conv2d", - ] - - fixed_source_order = [ - "down_blocks.2.attentions.0.proj_in", - "down_blocks.2.attentions.0.proj_out", - "controlnet_down_blocks.7", - "down_blocks.2.attentions.1.proj_in", - "down_blocks.2.attentions.1.proj_out", - "controlnet_down_blocks.8", - "controlnet_down_blocks.9", - "controlnet_down_blocks.10", - "controlnet_down_blocks.11", - "mid_block.attentions.0.proj_in", - "mid_block.attentions.0.proj_out", - "controlnet_mid_block", - ] - - assert source_order[broken_k] == expected_source_order - assert target_order[broken_k] == expected_target_order - source_order[broken_k] = fixed_source_order - - assert converter._assert_shapes_aligned(source_order=source_order, target_order=target_order), "Shapes not aligned" - - mapping: dict[str, str] = {} - for model_type_shape in source_order: - source_keys = source_order[model_type_shape] - target_keys = target_order[model_type_shape] - mapping.update(zip(target_keys, source_keys)) - - state_dict = converter._convert_state_dict( - source_state_dict=controlnet_src.state_dict(), - target_state_dict=controlnet.state_dict(), - state_dict_mapping=mapping, - ) - - return {k: v.half() for k, v in state_dict.items()} - - -def main() -> None: - parser = argparse.ArgumentParser(description="Convert a diffusers ControlNet model to a Refiners ControlNet model") - parser.add_argument( - "--from", - type=str, - dest="source_path", - default="lllyasviel/sd-controlnet-depth", - help=( - "Can be a path to a .bin, a .safetensors file, or a model identifier from Hugging Face Hub. Defaults to" - " lllyasviel/sd-controlnet-depth" - ), - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - required=False, - default=None, - help=( - "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" - " source path." - ), - ) - args = parser.parse_args(namespace=Args()) - if args.output_path is None: - args.output_path = f"{Path(args.source_path).stem}-controlnet.safetensors" - state_dict = convert(args=args) - save_to_safetensors(path=args.output_path, tensors=state_dict) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py deleted file mode 100644 index 8fcc3fe..0000000 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ /dev/null @@ -1,154 +0,0 @@ -import argparse -from pathlib import Path - -import torch - -from refiners.fluxion.utils import save_to_safetensors -from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet - -# Running: -# -# from diffusers import UNet2DConditionModel -# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") -# for k in unet.attn_processors.keys(): -# print(k) -# -# Gives: -# -# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor -# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor -# ... 
-# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
-# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
-# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
-# ...
-# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
-# mid_block.attentions.0.transformer_blocks.0.attn1.processor
-# mid_block.attentions.0.transformer_blocks.0.attn2.processor
-#
-# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
-# indices:
-#
-# DownBlocks -> [1, 3, 5, 7, 9, 11]
-# MiddleBlock -> [31]
-# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
-#
-# Same for SDXL with more layers (70 cross-attentions vs. 16)
CROSS_ATTN_MAPPING: dict[str, list[int]] = {
-    "sd15": list(range(1, 12, 2)) + [31] + list(range(13, 30, 2)),
-    "sdxl": list(range(1, 48, 2)) + list(range(121, 140, 2)) + list(range(49, 120, 2)),
-}
-
-
-def main() -> None:
-    parser = argparse.ArgumentParser(description="Converts an IP-Adapter diffusers model to refiners.")
-    parser.add_argument(
-        "--from",
-        type=str,
-        required=True,
-        dest="source_path",
-        help="Path to the source model. (e.g.: 'ip-adapter_sd15.bin').",
-    )
-    parser.add_argument(
-        "--to",
-        type=str,
-        dest="output_path",
-        default=None,
-        help=(
-            "Path to save the converted model. If not specified, the output path will be the source path with the"
-            " extension changed to .safetensors."
-        ),
-    )
-    parser.add_argument("--verbose", action="store_true", dest="verbose")
-    parser.add_argument("--half", action="store_true", dest="half")
-    args = parser.parse_args()
-    if args.output_path is None:
-        args.output_path = f"{Path(args.source_path).stem}.safetensors"
-
-    # Do not use `load_tensors`: first-level values are not tensors.
-    weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu")  # type: ignore
-    assert isinstance(weights, dict)
-    assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
-
-    image_proj_weights = weights["image_proj"]
-    ip_adapter_weights = weights["ip_adapter"]
-
-    fine_grained = "latents" in image_proj_weights  # aka IP-Adapter plus
-
-    match len(ip_adapter_weights):
-        case 32:
-            ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
-            cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
-        case 140:
-            ip_adapter = SDXLIPAdapter(target=SDXLUNet(in_channels=4), fine_grained=fine_grained)
-            cross_attn_mapping = CROSS_ATTN_MAPPING["sdxl"]
-        case _:
-            raise ValueError("Unexpected number of keys in input checkpoint")
-
-    # Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
-    # [1]: https://github.com/tencent-ailab/IP-Adapter
-
-    state_dict: dict[str, torch.Tensor] = {}
-
-    image_proj_state_dict: dict[str, torch.Tensor]
-
-    if fine_grained:
-        w = image_proj_weights
-        image_proj_state_dict = {
-            "LatentsToken.Parameter.weight": w["latents"].squeeze(0),  # drop batch dim = 1
-            "Linear_1.weight": w["proj_in.weight"],
-            "Linear_1.bias": w["proj_in.bias"],
-            "Linear_2.weight": w["proj_out.weight"],
-            "Linear_2.bias": w["proj_out.bias"],
-            "LayerNorm.weight": w["norm_out.weight"],
-            "LayerNorm.bias": w["norm_out.bias"],
-        }
-        for i in range(4):
-            t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}."
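            # e.g. for i == 0, this pairs refiners keys under
            # "Transformer.TransformerLayer_1.Residual_*" with source keys under "layers.0.*"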
-            image_proj_state_dict.update(
-                {
-                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"],
-                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"],
-                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"],
-                    f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"],
-                    f"{t_pfx}1.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"],
-                    f"{t_pfx}1.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"],
-                    f"{t_pfx}1.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"],
-                    f"{t_pfx}2.LayerNorm.weight": w[f"{s_pfx}1.0.weight"],
-                    f"{t_pfx}2.LayerNorm.bias": w[f"{s_pfx}1.0.bias"],
-                    f"{t_pfx}2.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"],
-                    f"{t_pfx}2.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"],
-                }
-            )
-    else:
-        image_proj_state_dict = {
-            "Linear.weight": image_proj_weights["proj.weight"],
-            "Linear.bias": image_proj_weights["proj.bias"],
-            "LayerNorm.weight": image_proj_weights["norm.weight"],
-            "LayerNorm.bias": image_proj_weights["norm.bias"],
-        }
-    ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
-
-    for k, v in image_proj_state_dict.items():
-        state_dict[f"image_proj.{k}"] = v
-
-    assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
-
-    for i, _ in enumerate(ip_adapter.sub_adapters):
-        cross_attn_index = cross_attn_mapping[i]
-        k_ip = f"{cross_attn_index}.to_k_ip.weight"
-        v_ip = f"{cross_attn_index}.to_v_ip.weight"
-
-        # the name of the key is not checked at runtime, so we keep the original name
-        state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
-        state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]
-
-    if args.half:
-        state_dict = {key: value.half() for key, value in state_dict.items()}
-    save_to_safetensors(path=args.output_path, tensors=state_dict)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/scripts/conversion/convert_diffusers_t2i_adapter.py b/scripts/conversion/convert_diffusers_t2i_adapter.py
deleted file mode 100644
index af3d024..0000000
--- a/scripts/conversion/convert_diffusers_t2i_adapter.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import argparse
-from pathlib import Path
-
-import torch
-from diffusers import T2IAdapter  # type: ignore
-from torch import nn
-
-from refiners.fluxion.model_converter import ModelConverter
-from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, ConditionEncoderXL
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert a pretrained diffusers T2I-Adapter model to refiners")
-    parser.add_argument(
-        "--from",
-        type=str,
-        dest="source_path",
-        required=True,
-        help="Path or repository name of the source model (e.g.: 'TencentARC/t2iadapter_depth_sd15v2').",
-    )
-    parser.add_argument(
-        "--to",
-        type=str,
-        dest="output_path",
-        default=None,
-        help=(
-            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
-            " be the source path with the extension changed to .safetensors."
-        ),
-    )
-    parser.add_argument(
-        "--half",
-        action="store_true",
-        dest="use_half",
-        default=False,
-        help="Use this flag to save the output file as half precision (default: full precision).",
-    )
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        dest="verbose",
-        default=False,
-        help="Use this flag to print verbose output during conversion.",
-    )
-    args = parser.parse_args()
-    if args.output_path is None:
-        args.output_path = f"{Path(args.source_path).name}.safetensors"
-    assert args.output_path is not None
-
-    sdxl = "xl" in args.source_path
-    target = ConditionEncoderXL() if sdxl else ConditionEncoder()
-    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
-    source: nn.Module = T2IAdapter.from_pretrained(  # type: ignore
-        pretrained_model_name_or_path=args.source_path,
-        low_cpu_mem_usage=False,
-    )
-    assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-
-    x = torch.randn(1, 3, 1024, 1024) if sdxl else torch.randn(1, 3, 512, 512)
-    converter = ModelConverter(source_model=source, target_model=target, verbose=args.verbose)
-    if not converter.run(source_args=(x,)):
-        raise RuntimeError("Model conversion failed")
-
-    converter.save_to_safetensors(path=args.output_path, half=args.use_half)
diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py
deleted file mode 100644
index 694f10f..0000000
--- a/scripts/conversion/convert_diffusers_unet.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import argparse
-from pathlib import Path
-from typing import Any
-
-import torch
-from diffusers import UNet2DConditionModel  # type: ignore
-from torch import nn
-
-from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import load_from_safetensors, load_tensors
-from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
-
-
-class Args(argparse.Namespace):
-    source_path: str
-    output_path: str | None
-    subfolder: str
-    half: bool
-    verbose: bool
-    skip_init_check: bool
-    override_weights: str | None
-
-
-def setup_converter(args: Args) -> ModelConverter:
-    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
-    source: nn.Module = UNet2DConditionModel.from_pretrained(  # type: ignore
-        pretrained_model_name_or_path=args.source_path,
-        subfolder=args.subfolder,
-        low_cpu_mem_usage=False,
-    )
-    if args.override_weights is not None:
-        if args.override_weights.endswith(".pth"):
-            sd = load_tensors(args.override_weights)
-        elif args.override_weights.endswith(".safetensors"):
-            sd = load_from_safetensors(args.override_weights)
-        else:
-            raise ValueError(f"Unsupported file format: {args.override_weights}")
-        source.load_state_dict(sd)
-    source_in_channels: int = source.config.in_channels  # type: ignore
-    source_clip_embedding_dim: int = source.config.cross_attention_dim  # type: ignore
-    source_has_time_ids: bool = source.config.addition_embed_type == "text_time"  # type: ignore
-    source_is_lcm: bool = source.config.time_cond_proj_dim is not None
-
-    if source_has_time_ids:
-        target = SDXLUNet(in_channels=source_in_channels)
-    else:
-        target = SD1UNet(in_channels=source_in_channels)
-
-    if source_is_lcm:
-        assert isinstance(target, SDXLUNet)
-        SDXLLcmAdapter(target=target).inject()
-
-    x = torch.randn(1, source_in_channels, 32, 32)
-    timestep = torch.tensor(data=[0])
-    clip_text_embeddings = torch.randn(1, 77,
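        # shape: (batch size, CLIP sequence length, cross-attention dim)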
source_clip_embedding_dim) - - target.set_timestep(timestep=timestep) - target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) - added_cond_kwargs = {} - if isinstance(target, SDXLUNet): - added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} - target.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) - target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) - - target_args = (x,) - - source_kwargs: dict[str, Any] = {} - if source_has_time_ids: - source_kwargs["added_cond_kwargs"] = added_cond_kwargs - if source_is_lcm: - source_kwargs["timestep_cond"] = torch.randn(1, source.config.time_cond_proj_dim) - - source_args = { - "positional": (x, timestep, clip_text_embeddings), - "keyword": source_kwargs, - } - - converter = ModelConverter( - source_model=source, - target_model=target, - skip_init_check=args.skip_init_check, - skip_output_check=True, - verbose=args.verbose, - ) - if not converter.run( - source_args=source_args, - target_args=target_args, - ): - raise RuntimeError("Model conversion failed") - return converter - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Converts a Diffusion UNet model to a Refiners SD1UNet or SDXLUNet model" - ) - parser.add_argument( - "--from", - type=str, - dest="source_path", - default="runwayml/stable-diffusion-v1-5", - help=( - "Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:" - " runwayml/stable-diffusion-v1-5" - ), - ) - parser.add_argument( - "--override-weights", - type=str, - default=None, - help=( - "Path to a weights file to override the source model (keeping its config). " - "This is useful for models distributed as .pth files." - ), - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - default=None, - help=( - "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" - " source path." - ), - ) - parser.add_argument("--subfolder", type=str, default="unet", help="Subfolder. Default: unet.") - parser.add_argument( - "--skip-init-check", - action="store_true", - help="Skip check that source and target have the same layers count.", - ) - parser.add_argument("--half", action="store_true", help="Convert to half precision.") - parser.add_argument( - "--verbose", - action="store_true", - default=False, - help="Prints additional information during conversion. 
Default: False", - ) - args = parser.parse_args(namespace=Args()) - if args.output_path is None: - args.output_path = f"{Path(args.source_path).stem}-unet.safetensors" - converter = setup_converter(args=args) - converter.save_to_safetensors(path=args.output_path, half=args.half) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py deleted file mode 100644 index ec4b204..0000000 --- a/scripts/conversion/convert_dinov2.py +++ /dev/null @@ -1,176 +0,0 @@ -import argparse -from pathlib import Path - -import torch - -from refiners.fluxion.utils import load_tensors, save_to_safetensors - - -def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None: - """Convert a DINOv2 weights from facebook to refiners.""" - # get depth from "blocks" keys - depth = max([int(k.split(".")[1]) for k in weights.keys() if k.startswith("blocks.")]) + 1 - - # only needed when pre-training - del weights["mask_token"] - - # squeeze cls_token and position_embeddings - weights["cls_token"] = weights["cls_token"].squeeze(0) - weights["pos_embed"] = weights["pos_embed"].squeeze(0) - - # rename "w12" to "fc1" and "w3" to "fc2", only for giant model - for key in list(weights.keys()): - if "w3" in key: - new_key = key.replace("w3", "fc2") - weights[new_key] = weights.pop(key) - elif "w12" in key: - # we swap w1 and w2 because of the difference between our GLU implementation and theirs - # see https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L31-L34 - # and https://github.com/finegrain-ai/refiners/blob/a2ee70578361e4d84a65a8708564480a9b0ec67e/src/refiners/fluxion/layers/activations.py#L158-L160 - weight = weights.pop(key) - w1, w2 = weight.chunk(2, dim=0) - w21 = torch.cat([w2, w1], dim=0) - new_key = key.replace("w12", "fc1") - weights[new_key] = w21 - - rename_keys: list[tuple[str, str]] = [ - ("cls_token", "Concatenate.ClassToken.Parameter.weight"), - ("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"), - ("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"), - ("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"), - ("norm.weight", "LayerNorm.weight"), - ("norm.bias", "LayerNorm.bias"), - ] - for i in range(depth): - rename_keys.append( - ( - f"blocks.{i}.norm1.weight", - f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.weight", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.norm1.bias", - f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.bias", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.attn.proj.weight", - f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.weight", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.attn.proj.bias", - f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.bias", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.ls1.gamma", - f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerScale.weight", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.norm2.weight", - f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.weight", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.norm2.bias", - f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.bias", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc1.weight", - f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.weight", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc1.bias", - 
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.bias", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc2.weight", - f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.weight", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.mlp.fc2.bias", - f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.bias", - ), - ) - rename_keys.append( - ( - f"blocks.{i}.ls2.gamma", - f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerScale.weight", - ), - ) - - if "register_tokens" in weights: - weights["register_tokens"] = weights["register_tokens"].squeeze(0) - rename_keys.append(("register_tokens", "Registers.Parameter.weight")) - - # rename keys - for old_key, new_key in rename_keys: - weights[new_key] = weights.pop(old_key) - - # split the qkv weights and biases - for i in range(depth): - qkv_weight = weights.pop(f"blocks.{i}.attn.qkv.weight") - q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) - weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.weight"] = q_weight - weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.weight"] = k_weight - weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.weight"] = v_weight - - qkv_bias = weights.pop(f"blocks.{i}.attn.qkv.bias") - q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) - weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.bias"] = q_bias - weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.bias"] = k_bias - weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.bias"] = v_bias - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - required=True, - dest="source_path", - help=( - "Official checkpoint from https://github.com/facebookresearch/dinov2" - " e.g. /path/to/dinov2_vits14_pretrain.pth" - ), - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - default=None, - help=( - "Path to save the converted model. If not specified, the output path will be the source path with the" - " extension changed to .safetensors." 
- ), - ) - parser.add_argument("--half", action="store_true", dest="half") - args = parser.parse_args() - - weights = load_tensors(args.source_path) - convert_dinov2_facebook(weights) - if args.half: - weights = {key: value.half() for key, value in weights.items()} - if args.output_path is None: - args.output_path = f"{Path(args.source_path).stem}.safetensors" - save_to_safetensors(path=args.output_path, tensors=weights) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_ella_adapter.py b/scripts/conversion/convert_ella_adapter.py deleted file mode 100644 index 74aaa74..0000000 --- a/scripts/conversion/convert_ella_adapter.py +++ /dev/null @@ -1,102 +0,0 @@ -import argparse -from pathlib import Path - -import torch -from huggingface_hub import hf_hub_download # type: ignore - -from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors - - -class Args(argparse.Namespace): - source_path: str - output_path: str | None - use_half: bool - - -def convert(args: Args) -> dict[str, torch.Tensor]: - if Path(args.source_path).suffix != ".safetensors": - args.source_path = hf_hub_download( - repo_id=args.source_path, filename="ella-sd1.5-tsc-t5xl.safetensors", local_dir="tests/weights/ELLA-Adapter" - ) - weights = load_from_safetensors(args.source_path) - - for key in list(weights.keys()): - if "latents" in key: - new_key = "PerceiverResampler.Latents.ParameterInitialized.weight" - weights[new_key] = weights.pop(key) - elif "time_embedding" in key: - new_key = key.replace("time_embedding", "TimestepEncoder.RangeEncoder").replace("linear", "Linear") - weights[new_key] = weights.pop(key) - elif "proj_in" in key: - new_key = f"PerceiverResampler.Linear.{key.split('.')[-1]}" - weights[new_key] = weights.pop(key) - elif "time_aware" in key: - new_key = f"PerceiverResampler.Residual.Linear.{key.split('.')[-1]}" - weights[new_key] = weights.pop(key) - elif "attn.in_proj" in key: - layer_num = int(key.split(".")[2]) - query_param, key_param, value_param = weights.pop(key).chunk(3, dim=0) - param_type = "weight" if "weight" in key else "bias" - for i, param in enumerate([query_param, key_param, value_param]): - new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Distribute.Linear_{i+1}.{param_type}" - weights[new_key] = param - elif "attn.out_proj" in key: - layer_num = int(key.split(".")[2]) - new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Linear.{key.split('.')[-1]}" - weights[new_key] = weights.pop(key) - elif "ln_ff" in key: - layer_num = int(key.split(".")[2]) - new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.AdaLayerNorm.Parallel.Chain.Linear.{key.split('.')[-1]}" - weights[new_key] = weights.pop(key) - elif "ln_1" in key or "ln_2" in key: - layer_num = int(key.split(".")[2]) - n = 1 if int(key.split(".")[3].split("_")[-1]) == 2 else 2 - new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Distribute.AdaLayerNorm_{n}.Parallel.Chain.Linear.{key.split('.')[-1]}" - weights[new_key] = weights.pop(key) - elif "mlp" in key: - layer_num = int(key.split(".")[2]) - n = 1 if "c_fc" in key else 2 - new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.FeedForward.Linear_{n}.{key.split('.')[-1]}" - weights[new_key] = weights.pop(key) - - if args.use_half: - weights = {key: value.half() for key, value in weights.items()} - - 
-    return weights
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert a pretrained ELLA Adapter to the refiners implementation")
-    parser.add_argument(
-        "--from",
-        type=str,
-        dest="source_path",
-        default="QQGYLab/ELLA",
-        help=(
-            "A path to local .safetensors weights. If not provided, the weights will be downloaded from the Hugging"
-            " Face Hub (default: 'QQGYLab/ELLA')."
-        ),
-    )
-
-    parser.add_argument(
-        "--to",
-        type=str,
-        dest="output_path",
-        default=None,
-        help=(
-            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
-            " be the source path stem with '-refiners' appended."
-        ),
-    )
-    parser.add_argument(
-        "--half",
-        action="store_true",
-        dest="use_half",
-        default=True,
-        help="Save the output file as half precision (enabled by default for this script).",
-    )
-    args = parser.parse_args(namespace=Args())
-    weights = convert(args)
-    if args.output_path is None:
-        args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
-    save_to_safetensors(path=args.output_path, tensors=weights)
diff --git a/scripts/conversion/convert_fooocus_control_lora.py b/scripts/conversion/convert_fooocus_control_lora.py
deleted file mode 100644
index b378973..0000000
--- a/scripts/conversion/convert_fooocus_control_lora.py
+++ /dev/null
@@ -1,348 +0,0 @@
-import argparse
-import logging
-from logging import info
-from pathlib import Path
-
-from huggingface_hub import hf_hub_download  # type: ignore
-from torch import Tensor
-from torch.nn import Parameter as TorchParameter
-
-from refiners.fluxion.adapters.lora import Lora, LoraAdapter, auto_attach_loras
-from refiners.fluxion.layers import Conv2d
-from refiners.fluxion.layers.linear import Linear
-from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import SDLoraManager
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import (
-    ConditionEncoder,
-    ControlLora,
-    ControlLoraAdapter,
-    ZeroConvolution,
-)
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
-
-
-def sort_keys(key: str, /) -> tuple[str, int]:
-    """Compute the score of a key, relative to its suffix.
-
-    When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level".
-
-    Args:
-        key: The key to sort.
-
-    Returns:
-        The padded suffix of the key.
-        The score of the key's suffix.
-    """
-    if "time_embed" in key:  # HACK: will place the "time_embed" layers at the very start of the list
-        return ("", -2)
-
-    if "label_emb" in key:  # HACK: will place the "label_emb" layers right after "time_embed"
-        return ("", -1)
-
-    if "proj_out" in key:  # HACK: will place the "proj_out" layers at the end of each "transformer_blocks"
-        return (key.removesuffix("proj_out") + "transformer_blocks.99.ff.net.2", 10)
-
-    return SDLoraManager.sort_keys(key)
-
-
-def load_lora_layers(
-    name: str,
-    state_dict: dict[str, Tensor],
-    control_lora: ControlLora,
-) -> dict[str, Lora[Linear | Conv2d]]:
-    """Load the LoRA layers from the state_dict into the ControlLora.
-
-    Args:
-        name: The name of the LoRA.
-        state_dict: The state_dict of the LoRA.
-        control_lora: The ControlLora to load the LoRA layers into.
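    Returns:
        The LoRA layers, keyed by name and sorted with `sort_keys`.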
- """ - # filter from the state_dict the layers that will be used for the LoRA layers - lora_weights = {f"{key}.weight": value for key, value in state_dict.items() if ".up" in key or ".down" in key} - - # move the tensors to the device and dtype of the ControlLora - lora_weights = { - key: value.to( - dtype=control_lora.dtype, - device=control_lora.device, - ) - for key, value in lora_weights.items() - } - - # load every LoRA layers from the filtered state_dict - lora_layers = Lora.from_dict(name, state_dict=lora_weights) - - # sort all the LoRA's keys using the `sort_keys` method - lora_layers = { - key: lora_layers[key] - for key in sorted( - lora_layers.keys(), - key=sort_keys, - ) - } - - # auto-attach the LoRA layers to the U-Net - auto_attach_loras(lora_layers, control_lora, exclude=["ZeroConvolution", "ConditionEncoder"]) - - # eject all the LoRA adapters from the U-Net - # because we need each target path as if the adapter wasn't injected - for lora_layer in lora_layers.values(): - lora_adapter = lora_layer.parent - assert isinstance(lora_adapter, LoraAdapter) - lora_adapter.eject() - - return lora_layers - - -def load_condition_encoder( - state_dict: dict[str, Tensor], - control_lora: ControlLora, -) -> None: - """Load the ConditionEncoder's Conv2d layers from the state_dict into the ControlLora. - - Args: - state_dict: The state_dict of the ConditionEncoder. - control_lora: The control_lora to load the ConditionEncoder's Conv2d layers into. - """ - # filter from the state_dict the layers that will be used for the ConditionEncoder - condition_encoder_tensors = {key: value for key, value in state_dict.items() if "input_hint_block" in key} - - # move the tensors to the device and dtype of the ControlLora - condition_encoder_tensors = { - key: value.to( - dtype=control_lora.dtype, - device=control_lora.device, - ) - for key, value in condition_encoder_tensors.items() - } - - # find the ConditionEncoder's Conv2d layers - condition_encoder_layer = control_lora.ensure_find(ConditionEncoder) - condition_encoder_conv2ds = list(condition_encoder_layer.layers(Conv2d)) - - # replace the Conv2d layers' weights and biases with the ones from the state_dict - for i, layer in enumerate(condition_encoder_conv2ds): - layer.weight = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.weight"]) - layer.bias = TorchParameter(condition_encoder_tensors[f"input_hint_block.{i*2}.bias"]) - - -def load_zero_convolutions( - state_dict: dict[str, Tensor], - control_lora: ControlLora, -) -> None: - """Load the ZeroConvolution's Conv2d layers from the state_dict into the ControlLora. - - Args: - state_dict: The state_dict of the ZeroConvolution. - control_lora: The ControlLora to load the ZeroConvolution's Conv2d layers into. 
- """ - # filter from the state_dict the layers that will be used for the ZeroConvolution layers - zero_convolution_tensors = {key: value for key, value in state_dict.items() if "zero_convs" in key} - n = len(zero_convolution_tensors) // 2 - zero_convolution_tensors[f"zero_convs.{n}.0.weight"] = state_dict["middle_block_out.0.weight"] - zero_convolution_tensors[f"zero_convs.{n}.0.bias"] = state_dict["middle_block_out.0.bias"] - - # move the tensors to the device and dtype of the ControlLora - zero_convolution_tensors = { - key: value.to( - dtype=control_lora.dtype, - device=control_lora.device, - ) - for key, value in zero_convolution_tensors.items() - } - - # find the ZeroConvolution's Conv2d layers - zero_convolution_layers = list(control_lora.layers(ZeroConvolution)) - zero_convolution_conv2ds = [layer.ensure_find(Conv2d) for layer in zero_convolution_layers] - - # replace the Conv2d layers' weights and biases with the ones from the state_dict - for i, layer in enumerate(zero_convolution_conv2ds): - layer.weight = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.weight"]) - layer.bias = TorchParameter(zero_convolution_tensors[f"zero_convs.{i}.0.bias"]) - - -def simplify_key(key: str, prefix: str, index: int | None = None) -> str: - """Simplify a key by stripping everything to the left of the prefix. - - Also optionally add a zero-padded index to the prefix. - - Example: - >>> simplify_key("foo.bar.ControlLora.something", "ControlLora", 1) - "ControlLora_01.something" - - >>> simplify_key("foo.bar.ControlLora.DownBlocks.something", "ControlLora") - "ControlLora.DownBlocks.something" - - Args: - key: The key to simplify. - prefix: The prefix to remove. - index: The index to add. - """ - _, right = key.split(prefix, maxsplit=1) - if index: - return f"{prefix}_{index:02d}{right}" - else: - return f"{prefix}{right}" - - -def convert_lora_layers( - lora_layers: dict[str, Lora[Linear | Conv2d]], - control_lora: ControlLora, - refiners_state_dict: dict[str, Tensor], -) -> None: - """Convert the LoRA layers to the refiners format. - - Args: - lora_layers: The LoRA layers to convert. - control_lora: The ControlLora to convert the LoRA layers from. - refiners_state_dict: The refiners state dict to update with the converted LoRA layers. - """ - for lora_layer in lora_layers.values(): - # get the adapter associated with the LoRA layer - lora_adapter = lora_layer.parent - assert isinstance(lora_adapter, LoraAdapter) - - # get the path of the adapter's target in the ControlLora - target = lora_adapter.target - path = target.get_path(parent=control_lora.ensure_find_parent(target)) - - state_dict = { - f"{path}.down": lora_layer.down.weight, - f"{path}.up": lora_layer.up.weight, - } - state_dict = {simplify_key(key, "ControlLora."): param for key, param in state_dict.items()} - refiners_state_dict.update(state_dict) - - -def convert_zero_convolutions( - control_lora: ControlLora, - refiners_state_dict: dict[str, Tensor], -) -> None: - """Convert the ZeroConvolution layers to the refiners format. - - Args: - control_lora: The ControlLora to convert the ZeroConvolution layers from. - refiners_state_dict: The refiners state dict to update with the converted ZeroConvolution layers. 
- """ - zero_convolution_layers = list(control_lora.layers(ZeroConvolution)) - for i, zero_convolution_layer in enumerate(zero_convolution_layers): - state_dict = zero_convolution_layer.state_dict() - path = zero_convolution_layer.get_path() - state_dict = {f"{path}.{key}": param for key, param in state_dict.items()} - state_dict = {simplify_key(key, "ZeroConvolution", i + 1): param for key, param in state_dict.items()} - refiners_state_dict.update(state_dict) - - -def convert_condition_encoder( - control_lora: ControlLora, - refiners_state_dict: dict[str, Tensor], -) -> None: - """Convert the ConditionEncoder to the refiners format. - - Args: - control_lora: The ControlLora to convert the ConditionEncoder from. - refiners_state_dict: The refiners state dict to update with the converted ConditionEncoder. - """ - condition_encoder_layer = control_lora.ensure_find(ConditionEncoder) - path = condition_encoder_layer.get_path() - state_dict = condition_encoder_layer.state_dict() - state_dict = {f"{path}.{key}": param for key, param in state_dict.items()} - state_dict = {simplify_key(key, "ConditionEncoder"): param for key, param in state_dict.items()} - refiners_state_dict.update(state_dict) - - -def convert( - name: str, - state_dict_path: Path, - output_path: Path, -) -> None: - sdxl = StableDiffusion_XL() - info("Stable Diffusion XL model initialized") - - fooocus_state_dict = load_from_safetensors(state_dict_path) - info(f"Fooocus weights loaded from: {state_dict_path}") - - control_lora_adapter = ControlLoraAdapter(target=sdxl.unet, name=name).inject() - control_lora = control_lora_adapter.control_lora - info("ControlLoraAdapter initialized") - - lora_layers = load_lora_layers(name, fooocus_state_dict, control_lora) - info("LoRA layers loaded") - - load_zero_convolutions(fooocus_state_dict, control_lora) - info("ZeroConvolution layers loaded") - - load_condition_encoder(fooocus_state_dict, control_lora) - info("ConditionEncoder loaded") - - refiners_state_dict: dict[str, Tensor] = {} - convert_lora_layers(lora_layers, control_lora, refiners_state_dict) - info("LoRA layers converted to refiners format") - - convert_zero_convolutions(control_lora, refiners_state_dict) - info("ZeroConvolution layers converted to refiners format") - - convert_condition_encoder(control_lora, refiners_state_dict) - info("ConditionEncoder converted to refiners format") - - output_path.parent.mkdir(parents=True, exist_ok=True) - save_to_safetensors(path=output_path, tensors=refiners_state_dict) - info(f"Converted ControlLora state dict saved to disk at: {output_path}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Convert ControlLora (from Fooocus) weights to refiners.", - ) - - parser.add_argument( - "--from", - type=Path, - dest="source_path", - default="lllyasviel/misc:control-lora-canny-rank128.safetensors", - help="Path to the state_dict of the ControlLora, or a Hugging Face model ID.", - ) - - parser.add_argument( - "--to", - type=Path, - dest="output_path", - help=( - "Path to save the converted model (extension will be .safetensors)." - "If not specified, the output path will be the source path with the extension changed to .safetensors." 
- ), - ) - - parser.add_argument( - "--verbose", - action="store_true", - dest="verbose", - default=False, - help="Use this flag to print verbose output during conversion.", - ) - - args = parser.parse_args() - - if args.verbose: - logging.basicConfig( - level=logging.INFO, - format="%(levelname)s: %(message)s", - ) - - if not args.source_path.exists(): - repo_id, filename = str(args.source_path).split(":") - args.source_path = Path( - hf_hub_download( - repo_id=repo_id, - filename=filename, - ) - ) - - if args.output_path is None: - args.output_path = Path(f"refiners_{args.source_path.stem}.safetensors") - - convert( - name=args.source_path.stem, - state_dict_path=args.source_path, - output_path=args.output_path, - ) diff --git a/scripts/conversion/convert_hq_segment_anything.py b/scripts/conversion/convert_hq_segment_anything.py deleted file mode 100644 index cbb36b7..0000000 --- a/scripts/conversion/convert_hq_segment_anything.py +++ /dev/null @@ -1,81 +0,0 @@ -import argparse - -from torch import Tensor - -from refiners.fluxion.utils import load_tensors, save_to_safetensors - - -def main() -> None: - parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format") - parser.add_argument( - "--from", - type=str, - dest="source_path", - required=True, - default="sam_hq_vit_h.pth", - help="Path to the source model checkpoint.", - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - required=True, - default="refiners_sam_hq_vit_h.safetensors", - help="Path to save the converted model in Refiners format.", - ) - args = parser.parse_args() - - source_state_dict = load_tensors(args.source_path) - - state_dict: dict[str, Tensor] = {} - - for suffix in ["weight", "bias"]: - state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_1.{suffix}"] = source_state_dict[ - f"mask_decoder.compress_vit_feat.0.{suffix}" - ] - state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_1.{suffix}"] = source_state_dict[ - f"mask_decoder.embedding_encoder.0.{suffix}" - ] - state_dict[f"EmbeddingMaskfeature.Conv2d_1.{suffix}"] = source_state_dict[ - f"mask_decoder.embedding_maskfeature.0.{suffix}" - ] - - state_dict[f"HQFeatures.CompressViTFeat.LayerNorm2d.{suffix}"] = source_state_dict[ - f"mask_decoder.compress_vit_feat.1.{suffix}" - ] - state_dict[f"HQFeatures.EmbeddingEncoder.LayerNorm2d.{suffix}"] = source_state_dict[ - f"mask_decoder.embedding_encoder.1.{suffix}" - ] - state_dict[f"EmbeddingMaskfeature.LayerNorm2d.{suffix}"] = source_state_dict[ - f"mask_decoder.embedding_maskfeature.1.{suffix}" - ] - - state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_2.{suffix}"] = source_state_dict[ - f"mask_decoder.compress_vit_feat.3.{suffix}" - ] - state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_2.{suffix}"] = source_state_dict[ - f"mask_decoder.embedding_encoder.3.{suffix}" - ] - state_dict[f"EmbeddingMaskfeature.Conv2d_2.{suffix}"] = source_state_dict[ - f"mask_decoder.embedding_maskfeature.3.{suffix}" - ] - - state_dict = {f"Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ.{k}": v for k, v in state_dict.items()} - - # HQ Token - state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"] - - # HQ MLP - for i in range(3): - state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.weight"] = source_state_dict[ - f"mask_decoder.hf_mlp.layers.{i}.weight" - ] - state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.bias"] = source_state_dict[ - 
f"mask_decoder.hf_mlp.layers.{i}.bias" - ] - - save_to_safetensors(path=args.output_path, tensors=state_dict) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_ic_light.py b/scripts/conversion/convert_ic_light.py deleted file mode 100644 index ddc367e..0000000 --- a/scripts/conversion/convert_ic_light.py +++ /dev/null @@ -1,89 +0,0 @@ -import argparse -from pathlib import Path - -from convert_diffusers_unet import Args as UNetArgs, setup_converter as setup_unet_converter -from huggingface_hub import hf_hub_download # type: ignore - -from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors - - -class Args(argparse.Namespace): - source_path: str - output_path: str | None - subfolder: str - half: bool - verbose: bool - reference_unet_path: str - - -def main() -> None: - parser = argparse.ArgumentParser(description="Converts IC-Light patch weights to work with Refiners") - parser.add_argument( - "--from", - type=str, - dest="source_path", - default="lllyasviel/ic-light", - help=( - "Can be a path to a .bin file, a .safetensors file or a model name from the Hugging Face Hub. Default:" - " lllyasviel/ic-light" - ), - ) - parser.add_argument("--filename", type=str, default="iclight_sd15_fc.safetensors", help="Filename inside the hub.") - parser.add_argument( - "--to", - type=str, - dest="output_path", - default=None, - help=( - "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" - " source path." - ), - ) - parser.add_argument( - "--verbose", - action="store_true", - default=False, - help="Prints additional information during conversion. Default: False", - ) - parser.add_argument( - "--reference-unet-path", - type=str, - dest="reference_unet_path", - default="runwayml/stable-diffusion-v1-5", - help="Path to the reference UNet weights.", - ) - args = parser.parse_args(namespace=Args()) - if args.output_path is None: - args.output_path = f"{Path(args.filename).stem}-refiners.safetensors" - - patch_file = ( - Path(args.source_path) - if args.source_path.endswith(".safetensors") - else Path( - hf_hub_download( - repo_id=args.source_path, - filename=args.filename, - ) - ) - ) - patch_weights = load_from_safetensors(patch_file) - - unet_args = UNetArgs( - source_path=args.reference_unet_path, - subfolder="unet", - half=False, - verbose=False, - skip_init_check=True, - override_weights=None, - ) - converter = setup_unet_converter(args=unet_args) - result = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage] - source_state_dict=patch_weights, - target_state_dict=converter.target_model.state_dict(), - state_dict_mapping=converter.get_mapping(), - ) - save_to_safetensors(path=args.output_path, tensors=result) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_informative_drawings.py b/scripts/conversion/convert_informative_drawings.py deleted file mode 100644 index df350b6..0000000 --- a/scripts/conversion/convert_informative_drawings.py +++ /dev/null @@ -1,65 +0,0 @@ -import argparse -from typing import cast - -import torch -from torch import nn - -from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import load_tensors -from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings - -try: - from model import Generator # type: ignore -except ImportError: - raise ImportError( - "Please download the model.py file from https://github.com/carolineec/informative-drawings and add 
it to your" - " PYTHONPATH" - ) - - -class Args(argparse.Namespace): - source_path: str - output_path: str - verbose: bool - half: bool - - -def setup_converter(args: Args) -> ModelConverter: - source = cast(nn.Module, Generator(3, 1, 3)) - source.load_state_dict(state_dict=load_tensors(args.source_path)) - source.eval() - target = InformativeDrawings() - x = torch.randn(1, 3, 512, 512) - converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) - if not converter.run(source_args=(x,)): - raise RuntimeError("Model conversion failed") - return converter - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Converts a pretrained Informative Drawings model to a refiners Informative Drawings model" - ) - parser.add_argument( - "--from", - type=str, - dest="source_path", - default="model2.pth", - help="Path to the source model. (default: 'model2.pth').", - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - default="informative-drawings.safetensors", - help="Path to save the converted model. (default: 'informative-drawings.safetensors').", - ) - parser.add_argument("--verbose", action="store_true", dest="verbose") - parser.add_argument("--half", action="store_true", dest="half") - args = parser.parse_args(namespace=Args()) - converter = setup_converter(args=args) - converter.save_to_safetensors(path=args.output_path, half=args.half) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_mvanet.py b/scripts/conversion/convert_mvanet.py deleted file mode 100644 index e5b30e7..0000000 --- a/scripts/conversion/convert_mvanet.py +++ /dev/null @@ -1,40 +0,0 @@ -import argparse -from pathlib import Path - -from refiners.fluxion.utils import load_tensors, save_to_safetensors -from refiners.foundationals.swin.mvanet.converter import convert_weights - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--from", - type=str, - required=True, - dest="source_path", - help="A MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet", - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - default=None, - help=( - "Path to save the converted model. If not specified, the output path will be the source path with the" - " extension changed to .safetensors." 
- ), - ) - parser.add_argument("--half", action="store_true", dest="half") - args = parser.parse_args() - - src_weights = load_tensors(args.source_path) - weights = convert_weights(src_weights) - if args.half: - weights = {key: value.half() for key, value in weights.items()} - if args.output_path is None: - args.output_path = f"{Path(args.source_path).stem}.safetensors" - save_to_safetensors(path=args.output_path, tensors=weights) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py deleted file mode 100644 index 227c58a..0000000 --- a/scripts/conversion/convert_segment_anything.py +++ /dev/null @@ -1,268 +0,0 @@ -import argparse -import types -from typing import Any, Callable, cast - -import torch -import torch.nn as nn -from segment_anything import build_sam_vit_h # type: ignore -from segment_anything.modeling.common import LayerNorm2d # type: ignore -from torch import Tensor - -import refiners.fluxion.layers as fl -from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors -from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH -from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder -from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder - - -class FacebookSAM(nn.Module): - image_encoder: nn.Module - prompt_encoder: nn.Module - mask_decoder: nn.Module - - -build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h) - - -assert issubclass(LayerNorm2d, nn.Module) -custom_layers = {LayerNorm2d: fl.LayerNorm2d} - - -class Args(argparse.Namespace): - source_path: str - output_path: str - half: bool - verbose: bool - - -def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: - manual_seed(seed=0) - refiners_mask_encoder = MaskEncoder() - - converter = ModelConverter( - source_model=prompt_encoder.mask_downscaling, - target_model=refiners_mask_encoder, - custom_layer_mapping=custom_layers, # type: ignore - ) - - x = torch.randn(1, 256, 256) - mapping = converter.map_state_dicts(source_args=(x,)) - assert mapping - - source_state_dict = prompt_encoder.mask_downscaling.state_dict() - target_state_dict = refiners_mask_encoder.state_dict() - - # Mapping handled manually (see below) because nn.Parameter is a special case - del target_state_dict["no_mask_embedding"] - - converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage] - source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping - ) - - state_dict: dict[str, Tensor] = { - "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore - } - - state_dict.update(converted_source) - - refiners_mask_encoder.load_state_dict(state_dict=state_dict) - - return state_dict - - -def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: - manual_seed(seed=0) - point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [ - prompt_encoder.not_a_point_embed.weight - ] # type: ignore - pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore - assert isinstance(pe, Tensor) - state_dict: dict[str, Tensor] = { - "Residual.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)), - "CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()), - 
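        # The four point-type embeddings plus `not_a_point_embed` are concatenated
        # into a single embedding weight; transposing the random Gaussian matrix
        # turns it into a Linear weight mapping 2D coordinates to positional features.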
} - - refiners_prompt_encoder = PointEncoder() - refiners_prompt_encoder.load_state_dict(state_dict=state_dict) - - return state_dict - - -def convert_vit(vit: nn.Module) -> dict[str, Tensor]: - manual_seed(seed=0) - refiners_sam_vit_h = SAMViTH() - - converter = ModelConverter( - source_model=vit, - target_model=refiners_sam_vit_h, - custom_layer_mapping=custom_layers, # type: ignore - ) - converter.skip_init_check = True - - x = torch.randn(1, 3, 1024, 1024) - mapping = converter.map_state_dicts(source_args=(x,)) - assert mapping - - mapping["PositionalEncoder.Parameter.weight"] = "pos_embed" - - target_state_dict = refiners_sam_vit_h.state_dict() - del target_state_dict["PositionalEncoder.Parameter.weight"] - - source_state_dict = vit.state_dict() - pos_embed = source_state_dict["pos_embed"] - del source_state_dict["pos_embed"] - - target_rel_keys = [ - ( - f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.horizontal_embedding", - f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.vertical_embedding", - ) - for i in range(1, 33) - ] - source_rel_keys = [(f"blocks.{i}.attn.rel_pos_w", f"blocks.{i}.attn.rel_pos_h") for i in range(32)] - - rel_items: dict[str, Tensor] = {} - - for (key_w, key_h), (target_key_w, target_key_h) in zip(source_rel_keys, target_rel_keys): - rel_items[target_key_w] = source_state_dict[key_w] - rel_items[target_key_h] = source_state_dict[key_h] - del source_state_dict[key_w] - del source_state_dict[key_h] - del target_state_dict[target_key_w] - del target_state_dict[target_key_h] - - converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage] - source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping - ) - - positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder) - embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight) - converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore - converted_source.update(rel_items) - - refiners_sam_vit_h.load_state_dict(state_dict=converted_source) - assert converter.compare_models((x,), threshold=1e-2) - - return converted_source - - -def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]: - manual_seed(seed=0) - - refiners_mask_decoder = MaskDecoder() - - image_embedding = torch.randn(1, 256, 64, 64) - dense_positional_embedding = torch.randn(1, 256, 64, 64) - point_embedding = torch.randn(1, 3, 256) - mask_embedding = torch.randn(1, 256, 64, 64) - - from segment_anything.modeling.common import LayerNorm2d # type: ignore - - import refiners.fluxion.layers as fl - - assert issubclass(LayerNorm2d, nn.Module) - custom_layers = {LayerNorm2d: fl.LayerNorm2d} - - converter = ModelConverter( - source_model=mask_decoder, - target_model=refiners_mask_decoder, - custom_layer_mapping=custom_layers, # type: ignore - ) - - inputs = { - "image_embeddings": image_embedding, - "image_pe": dense_positional_embedding, - "sparse_prompt_embeddings": point_embedding, - "dense_prompt_embeddings": mask_embedding, - "multimask_output": True, - } - - refiners_mask_decoder.set_image_embedding(image_embedding) - refiners_mask_decoder.set_point_embedding(point_embedding) - refiners_mask_decoder.set_mask_embedding(mask_embedding) - refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding) - - mapping = converter.map_state_dicts(source_args=inputs, target_args={}) - assert mapping is not 
None - mapping["MaskDecoderTokens.Parameter"] = "iou_token" - - state_dict = converter._convert_state_dict( # type: ignore - source_state_dict=mask_decoder.state_dict(), - target_state_dict=refiners_mask_decoder.state_dict(), - state_dict_mapping=mapping, - ) - state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat( - tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0 - ) # type: ignore - refiners_mask_decoder.load_state_dict(state_dict=state_dict) - - refiners_mask_decoder.set_image_embedding(image_embedding) - refiners_mask_decoder.set_point_embedding(point_embedding) - refiners_mask_decoder.set_mask_embedding(mask_embedding) - refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding) - - # Perform (1) upscaling then (2) mask prediction in this order (= like in the official implementation) to make - # `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default) - matmul = refiners_mask_decoder.ensure_find(fl.Matmul) - - def forward_swapped_order(self: Any, *args: Any) -> Any: - y = self[1](*args) # (1) - x = self[0](*args) # (2) - return torch.matmul(input=x, other=y) - - matmul.forward = types.MethodType(forward_swapped_order, matmul) - - assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3) - - return state_dict - - -def main() -> None: - parser = argparse.ArgumentParser(description="Converts a Segment Anything ViT model to a Refiners SAMViTH model") - parser.add_argument( - "--from", - type=str, - dest="source_path", - default="sam_vit_h_4b8939.pth", - # required=True, - help="Path to the Segment Anything model weights", - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - default="segment-anything-h.safetensors", - help="Output path for converted model (as safetensors).", - ) - parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False") - parser.add_argument( - "--verbose", - action="store_true", - default=False, - help="Prints additional information during conversion. 
Default: False", - ) - args = parser.parse_args(namespace=Args()) - - sam_h = build_sam_vit_h() # type: ignore - sam_h.load_state_dict(state_dict=load_tensors(args.source_path)) - - vit_state_dict = convert_vit(vit=sam_h.image_encoder) - mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder) - point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder) - mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder) - - output_state_dict = { - **{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()}, - **{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()}, - **{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()}, - **{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()}, - } - if args.half: - output_state_dict = {key: value.half() for key, value in output_state_dict.items()} - - save_to_safetensors(path=args.output_path, tensors=output_state_dict) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_transformers_clip_image_model.py b/scripts/conversion/convert_transformers_clip_image_model.py deleted file mode 100644 index 121b370..0000000 --- a/scripts/conversion/convert_transformers_clip_image_model.py +++ /dev/null @@ -1,149 +0,0 @@ -import argparse -from pathlib import Path -from typing import NamedTuple, cast - -import torch -from torch import nn -from transformers import CLIPVisionModelWithProjection # type: ignore - -import refiners.fluxion.layers as fl -from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import save_to_safetensors -from refiners.foundationals.clip.image_encoder import CLIPImageEncoder - - -class Args(argparse.Namespace): - source_path: str - subfolder: str - output_path: str | None - half: bool - verbose: bool - threshold: float - - -class CLIPImageEncoderConfig(NamedTuple): - architectures: list[str] - num_channels: int - hidden_size: int - hidden_act: str - image_size: int - projection_dim: int - patch_size: int - num_hidden_layers: int - num_attention_heads: int - intermediate_size: int - layer_norm_eps: float - - -def setup_converter(args: Args) -> ModelConverter: - # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate` - source: nn.Module = CLIPVisionModelWithProjection.from_pretrained( # type: ignore - pretrained_model_name_or_path=args.source_path, - subfolder=args.subfolder, - low_cpu_mem_usage=False, - ) - assert isinstance(source, nn.Module), "Source model is not a nn.Module" - config = cast(CLIPImageEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType] - - assert ( - config.architectures[0] == "CLIPVisionModelWithProjection" - ), f"Unsupported architecture: {config.architectures[0]}" - assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}" - assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}" - - target = CLIPImageEncoder( - image_size=config.image_size, - embedding_dim=config.hidden_size, - output_dim=config.projection_dim, - patch_size=config.patch_size, - num_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - feedforward_dim=config.intermediate_size, - layer_norm_eps=config.layer_norm_eps, - ) - - x = torch.randn(1, 3, config.image_size, config.image_size) - - converter = ModelConverter(source_model=source, target_model=target, verbose=True) - - # 
-
-    # Custom conversion logic since the class embedding (fl.Parameter layer) is not supported out-of-the-box by the
-    # converter
-    mapping = converter.map_state_dicts((x,))
-    assert mapping is not None
-
-    source_state_dict = source.state_dict()
-    target_state_dict = target.state_dict()
-
-    # Remove the class embedding from state dict since it was not mapped by the model converter
-    class_embedding = target.ensure_find(fl.Parameter)
-    class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
-    assert class_embedding_key is not None
-    assert class_embedding_key in target_state_dict
-    del target_state_dict[class_embedding_key]
-
-    converted_state_dict = converter._convert_state_dict(  # type: ignore[reportPrivateUsage]
-        source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
-    )
-    target.load_state_dict(state_dict=converted_state_dict, strict=False)
-
-    # Ad hoc post-conversion steps
-    embed = source.vision_model.embeddings.class_embedding
-    class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight))  # type: ignore
-
-    assert converter.compare_models((x,), threshold=args.threshold)
-
-    return converter
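The ad hoc steps above are the escape hatch for weights the converter cannot trace automatically. A condensed, hedged sketch of the same manual-copy pattern, using plain PyTorch and hypothetical names (not part of the original script):

```py
# Hypothetical sketch of the manual-copy pattern used above for weights the
# converter cannot map automatically (here, a class-embedding-like parameter).
import torch
from torch import nn


class Target(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.class_embedding = nn.Parameter(torch.zeros(1, 1, 768))


target = Target()
source_embedding = torch.randn(768)  # stand-in for the source model's weight
# Reshape the source tensor to the target parameter's layout before assigning it.
target.class_embedding = nn.Parameter(source_embedding.clone().reshape_as(target.class_embedding))
```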
Default: 1e-2") - args = parser.parse_args(namespace=Args()) - if args.output_path is None: - args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors" - converter = setup_converter(args=args) - # Do not use converter.save_to_safetensors since it is not in a valid state due to the ad hoc conversion - state_dict = converter.target_model.state_dict() - if args.half: - state_dict = {key: value.half() for key, value in state_dict.items()} - save_to_safetensors(path=args.output_path, tensors=state_dict) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_transformers_clip_text_model.py b/scripts/conversion/convert_transformers_clip_text_model.py deleted file mode 100644 index fd45767..0000000 --- a/scripts/conversion/convert_transformers_clip_text_model.py +++ /dev/null @@ -1,150 +0,0 @@ -import argparse -from pathlib import Path -from typing import NamedTuple, cast - -from torch import nn -from transformers import CLIPTextModel, CLIPTextModelWithProjection # type: ignore - -import refiners.fluxion.layers as fl -from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import save_to_safetensors -from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderG, CLIPTextEncoderL -from refiners.foundationals.clip.tokenizer import CLIPTokenizer -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder - - -class Args(argparse.Namespace): - source_path: str - subfolder: str - output_path: str | None - half: bool - verbose: bool - - -class CLIPTextEncoderConfig(NamedTuple): - architectures: list[str] - vocab_size: int - hidden_size: int - intermediate_size: int - num_hidden_layers: int - num_attention_heads: int - hidden_act: str - layer_norm_eps: float - projection_dim: int - - -def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter: - # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate` - cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel - source: nn.Module = cls.from_pretrained( # type: ignore - pretrained_model_name_or_path=args.source_path, - subfolder=args.subfolder, - low_cpu_mem_usage=False, - ) - assert isinstance(source, nn.Module), "Source model is not a nn.Module" - config = cast(CLIPTextEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType] - architecture: str = config.architectures[0] - embedding_dim: int = config.hidden_size - projection_dim: int = config.projection_dim - use_quick_gelu = config.hidden_act == "quick_gelu" - assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}" - target = CLIPTextEncoder( - embedding_dim=config.hidden_size, - num_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - feedforward_dim=config.intermediate_size, - use_quick_gelu=use_quick_gelu, - ) - if architecture == "CLIPTextModelWithProjection": - target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False)) - text = "What a nice cat you have there!" 
- tokenizer = target.ensure_find(CLIPTokenizer) - tokens = tokenizer(text) - converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) - if not converter.run(source_args=(tokens,), target_args=(text,)): - raise RuntimeError("Model conversion failed") - return converter - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Converts a CLIPTextEncoder from the library transformers from the HuggingFace Hub to refiners." - ) - parser.add_argument( - "--from", - type=str, - dest="source_path", - default="runwayml/stable-diffusion-v1-5", - help=( - "Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:" - " runwayml/stable-diffusion-v1-5" - ), - ) - parser.add_argument( - "--subfolder", - type=str, - dest="subfolder", - default="text_encoder", - help=( - "Subfolder in the source path where the model is located inside the Hub. Default: text_encoder (for" - " CLIPTextModel)" - ), - ) - parser.add_argument( - "--subfolder2", - type=str, - dest="subfolder2", - default=None, - help="Additional subfolder for the 2nd text encoder (useful for SDXL). Default: None", - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - default=None, - help=( - "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" - " source path." - ), - ) - parser.add_argument("--half", action="store_true", help="Convert to half precision.") - parser.add_argument( - "--verbose", - action="store_true", - default=False, - help="Prints additional information during conversion. Default: False", - ) - args = parser.parse_args(namespace=Args()) - if args.output_path is None: - args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors" - converter = setup_converter(args=args) - if args.subfolder2 is not None: - # Assume this is the second text encoder of Stable Diffusion XL - args.subfolder = args.subfolder2 - converter2 = setup_converter(args=args, with_projection=True) - - text_encoder_l = CLIPTextEncoderL() - text_encoder_l.load_state_dict(state_dict=converter.get_state_dict()) - - projection = cast(CLIPTextEncoder, converter2.target_model)[-1] - assert isinstance(projection, fl.Linear) - text_encoder_g_with_projection = CLIPTextEncoderG() - text_encoder_g_with_projection.append(module=projection) - text_encoder_g_with_projection.load_state_dict(state_dict=converter2.get_state_dict()) - - projection = text_encoder_g_with_projection.pop(index=-1) - assert isinstance(projection, fl.Linear) - double_text_encoder = DoubleTextEncoder( - text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=projection - ) - - state_dict = double_text_encoder.state_dict() - if args.half: - state_dict = {key: value.half() for key, value in state_dict.items()} - save_to_safetensors(path=args.output_path, tensors=state_dict) - else: - converter.save_to_safetensors(path=args.output_path, half=args.half) - - -if __name__ == "__main__": - main() diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py deleted file mode 100644 index 356718e..0000000 --- a/scripts/prepare_test_weights.py +++ /dev/null @@ -1,945 +0,0 @@ -""" -Download and convert weights for testing - -To see what weights will be downloaded and converted, run: -DRY_RUN=1 python scripts/prepare_test_weights.py -""" - -import hashlib -import os -import subprocess -import sys -from urllib.parse import urlparse - -import gdown -import 
requests -from tqdm import tqdm - -# Set the base directory to the parent directory of the script -project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -test_weights_dir = os.path.join(project_dir, "tests", "weights") - -previous_line = "\033[F" - -download_count = 0 -bytes_count = 0 - - -def die(message: str) -> None: - print(message, file=sys.stderr) - sys.exit(1) - - -def rel(path: str) -> str: - return os.path.relpath(path, project_dir) - - -def calc_hash(filepath: str) -> str: - with open(filepath, "rb") as f: - data = f.read() - found = hashlib.blake2b(data, digest_size=int(32 / 8)).hexdigest() - return found - - -def check_hash(path: str, expected: str) -> str: - found = calc_hash(path) - if found != expected: - die(f"❌ Invalid hash for {path} ({found} != {expected})") - return found - - -def download_file( - url: str, - dest_folder: str, - dry_run: bool | None = None, - skip_existing: bool = True, - expected_hash: str | None = None, - filename: str | None = None, -): - """ - Downloads a file - - Features: - - shows a progress bar - - skips existing files - - uses a temporary file to prevent partial downloads - - can do a dry run to check the url is valid - - displays the downloaded file hash - - """ - global download_count, bytes_count - filename = os.path.basename(urlparse(url).path) if filename is None else filename - dest_filename = os.path.join(dest_folder, filename) - temp_filename = dest_filename + ".part" - dry_run = bool(os.environ.get("DRY_RUN") == "1") if dry_run is None else dry_run - - is_downloaded = os.path.exists(dest_filename) - if is_downloaded and skip_existing: - skip_icon = "✖️ " - else: - skip_icon = "🔽" - - if dry_run: - response = requests.head(url, allow_redirects=True) - readable_size = "" - - if response.status_code == 200: - content_length = response.headers.get("content-length") - - if content_length: - size_in_bytes = int(content_length) - readable_size = human_readable_size(size_in_bytes) - download_count += 1 - bytes_count += size_in_bytes - print(f"✅{skip_icon} {response.status_code} READY {readable_size:<8} {url}") - - else: - print(f"❌{skip_icon} {response.status_code} ERROR {readable_size:<8} {url}") - return - - if skip_existing and is_downloaded: - print(f"{skip_icon}️ Skipping previously downloaded {url}") - if expected_hash is not None: - check_hash(dest_filename, expected_hash) - return - - os.makedirs(dest_folder, exist_ok=True) - - print(f"🔽 Downloading {url} => '{rel(dest_filename)}'", end="\n") - response = requests.get(url, stream=True) - if response.status_code != 200: - print(response.content[:1000]) - die(f"Failed to download {url}. 
Status code: {response.status_code}") - total = int(response.headers.get("content-length", 0)) - bar = tqdm( - desc=filename, - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - leave=False, - ) - with open(temp_filename, "wb") as f, bar: - for data in response.iter_content(chunk_size=1024 * 1000): - size = f.write(data) - bar.update(size) - - os.rename(temp_filename, dest_filename) - calculated_hash = calc_hash(dest_filename) - - print(f"{previous_line}✅ Downloaded {calculated_hash} {url} => '{rel(dest_filename)}' ") - if expected_hash is not None: - check_hash(dest_filename, expected_hash) - - -def download_files(urls: list[str], dest_folder: str): - for url in urls: - download_file(url, dest_folder) - - -def human_readable_size(size: int | float, decimal_places: int = 2) -> str: - for unit in ["B", "KB", "MB", "GB", "TB", "PB"]: - if size < 1024.0: - break - size /= 1024.0 - return f"{size:.{decimal_places}f}{unit}" # type: ignore - - -def download_sd_text_encoder(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "text_encoder"): - encoder_filename = "model.safetensors" if "inpainting" not in hf_repo_id else "model.fp16.safetensors" - base_url = f"https://huggingface.co/{hf_repo_id}" - download_files( - urls=[ - f"{base_url}/raw/main/{subdir}/config.json", - f"{base_url}/resolve/main/{subdir}/{encoder_filename}", - ], - dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir), - ) - - -def download_sd_tokenizer(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "tokenizer"): - download_files( - urls=[ - f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/merges.txt", - f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/special_tokens_map.json", - f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/tokenizer_config.json", - f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/vocab.json", - ], - dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir), - ) - - -def download_sd_base(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"): - is_inpainting = "inpainting" in hf_repo_id - ext = "safetensors" if not is_inpainting else "bin" - base_folder = os.path.join(test_weights_dir, hf_repo_id) - download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/model_index.json", base_folder) - download_file( - f"https://huggingface.co/{hf_repo_id}/raw/main/scheduler/scheduler_config.json", - os.path.join(base_folder, "scheduler"), - ) - - for subdir in ["unet", "vae"]: - subdir_folder = os.path.join(base_folder, subdir) - download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder) - download_file( - f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/diffusion_pytorch_model.{ext}", subdir_folder - ) - # we only need the unet for the inpainting model - if not is_inpainting: - download_sd_text_encoder(hf_repo_id, "text_encoder") - download_sd_tokenizer(hf_repo_id, "tokenizer") - - -def download_sd15(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"): - download_sd_base(hf_repo_id) - base_folder = os.path.join(test_weights_dir, hf_repo_id) - - subdir = "feature_extractor" - download_file( - f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/preprocessor_config.json", - os.path.join(base_folder, subdir), - ) - - if "inpainting" not in hf_repo_id: - subdir = "safety_checker" - subdir_folder = os.path.join(base_folder, subdir) - download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder) - 
download_file(f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/model.safetensors", subdir_folder) - - -def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"): - download_sd_base(hf_repo_id) - download_sd_text_encoder(hf_repo_id, "text_encoder_2") - download_sd_tokenizer(hf_repo_id, "tokenizer_2") - - -def download_vae_fp16_fix(): - download_files( - urls=[ - "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/raw/main/config.json", - "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/diffusion_pytorch_model.safetensors", - ], - dest_folder=os.path.join(test_weights_dir, "madebyollin", "sdxl-vae-fp16-fix"), - ) - - -def download_vae_ft_mse(): - download_files( - urls=[ - "https://huggingface.co/stabilityai/sd-vae-ft-mse/raw/main/config.json", - "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.safetensors", - ], - dest_folder=os.path.join(test_weights_dir, "stabilityai", "sd-vae-ft-mse"), - ) - - -def download_loras(): - dest_folder = os.path.join(test_weights_dir, "loras", "pokemon-lora") - download_file( - "https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin", - dest_folder, - expected_hash="89992ea6", - ) - - dest_folder = os.path.join(test_weights_dir, "loras", "dpo-lora") - download_file( - "https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", - dest_folder, - expected_hash="a51e9144", - ) - - dest_folder = os.path.join(test_weights_dir, "loras", "sliders") - download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder, expected_hash="908f07d3") - download_file( - "https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder, expected_hash="25652004" - ) - download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder, expected_hash="ee170e4d") - - dest_folder = os.path.join(test_weights_dir, "loras") - download_file( - "https://civitai.com/api/download/models/140624", - filename="Sci-fi_Environments_sdxl.safetensors", - dest_folder=dest_folder, - expected_hash="6a4afda8", - ) - download_file( - "https://civitai.com/api/download/models/135931", - filename="pixel-art-xl-v1.1.safetensors", - dest_folder=dest_folder, - expected_hash="71aaa6ca", - ) - - -def download_preprocessors(): - dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings") - download_file("https://huggingface.co/spaces/carolineec/informativedrawings/resolve/main/model2.pth", dest_folder) - - -def download_controlnet(): - base_folder = os.path.join(test_weights_dir, "lllyasviel") - controlnets = [ - "control_v11p_sd15_canny", - "control_v11f1p_sd15_depth", - "control_v11p_sd15_normalbae", - "control_v11p_sd15_lineart", - ] - for net in controlnets: - net_folder = os.path.join(base_folder, net) - urls = [ - f"https://huggingface.co/lllyasviel/{net}/raw/main/config.json", - f"https://huggingface.co/lllyasviel/{net}/resolve/main/diffusion_pytorch_model.safetensors", - ] - download_files(urls, net_folder) - - tile_folder = os.path.join(base_folder, "control_v11f1e_sd15_tile") - urls = [ - "https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/raw/main/config.json", - "https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/main/diffusion_pytorch_model.bin", - ] - download_files(urls, tile_folder) - - mfidabel_folder = os.path.join(test_weights_dir, "mfidabel", "controlnet-segment-anything") - urls = [ - 
"https://huggingface.co/mfidabel/controlnet-segment-anything/raw/main/config.json", - "https://huggingface.co/mfidabel/controlnet-segment-anything/resolve/main/diffusion_pytorch_model.bin", - ] - download_files(urls, mfidabel_folder) - - -def download_control_lora_fooocus(): - base_folder = os.path.join(test_weights_dir, "lllyasviel", "misc") - - download_file( - url=f"https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors", - dest_folder=base_folder, - expected_hash="fec9e32b", - ) - - download_file( - url=f"https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors", - dest_folder=base_folder, - expected_hash="fc04b120", - ) - - -def download_unclip(): - base_folder = os.path.join(test_weights_dir, "stabilityai", "stable-diffusion-2-1-unclip") - download_file( - "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/model_index.json", base_folder - ) - image_encoder_folder = os.path.join(base_folder, "image_encoder") - urls = [ - "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/image_encoder/config.json", - "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.safetensors", - ] - download_files(urls, image_encoder_folder) - - -def download_ip_adapter(): - base_folder = os.path.join(test_weights_dir, "h94", "IP-Adapter") - models_folder = os.path.join(base_folder, "models") - urls = [ - "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.bin", - "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.bin", - ] - download_files(urls, models_folder) - - sdxl_models_folder = os.path.join(base_folder, "sdxl_models") - urls = [ - "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.bin", - "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin", - ] - download_files(urls, sdxl_models_folder) - - -def download_t5xl_fp16(): - base_folder = os.path.join(test_weights_dir, "QQGYLab", "T5XLFP16") - urls = [ - "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/config.json", - "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/model.safetensors", - "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/special_tokens_map.json", - "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/spiece.model", - "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer.json", - "https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer_config.json", - ] - download_files(urls, base_folder) - - -def download_ella_adapter(): - download_t5xl_fp16() - base_folder = os.path.join(test_weights_dir, "QQGYLab", "ELLA") - download_file( - "https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors", - base_folder, - expected_hash="5af7b200", - ) - - -def download_t2i_adapter(): - base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2") - urls = [ - "https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/raw/main/config.json", - "https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/resolve/main/diffusion_pytorch_model.bin", - ] - download_files(urls, base_folder) - - canny_sdxl_folder = os.path.join(test_weights_dir, "TencentARC", "t2i-adapter-canny-sdxl-1.0") - urls = [ - 
"https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/raw/main/config.json", - "https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors", - ] - download_files(urls, canny_sdxl_folder) - - -def download_sam(): - weights_folder = os.path.join(test_weights_dir) - download_file( - "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", weights_folder, expected_hash="06785e66" - ) - - -def download_hq_sam(): - weights_folder = os.path.join(test_weights_dir) - download_file( - "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", weights_folder, expected_hash="66da2472" - ) - - -def download_dinov2(): - # For conversion - weights_folder = os.path.join(test_weights_dir) - urls = [ - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth", - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth", - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth", - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth", - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth", - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth", - "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth", - ] - download_files(urls, weights_folder) - - -def download_lcm_base(): - base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl") - download_file(f"https://huggingface.co/latent-consistency/lcm-sdxl/raw/main/config.json", base_folder) - download_file( - f"https://huggingface.co/latent-consistency/lcm-sdxl/resolve/main/diffusion_pytorch_model.safetensors", - base_folder, - ) - - -def download_lcm_lora(): - download_file( - "https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors", - dest_folder=test_weights_dir, - filename="sdxl-lcm-lora.safetensors", - expected_hash="6312a30a", - ) - - -def download_sdxl_lightning_base(): - base_folder = os.path.join(test_weights_dir, "ByteDance/SDXL-Lightning") - download_file( - f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_unet.safetensors", - base_folder, - expected_hash="1b76cca3", - ) - download_file( - f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_1step_unet_x0.safetensors", - base_folder, - expected_hash="38e605bd", - ) - - -def download_sdxl_lightning_lora(): - download_file( - "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors", - dest_folder=test_weights_dir, - expected_hash="9783edac", - ) - - -def download_ic_light(): - download_file( - "https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors", - dest_folder=test_weights_dir, - expected_hash="bce70123", - ) - - -def download_mvanet(): - fn = "Model_80.pth" - dest_folder = os.path.join(test_weights_dir, "mvanet") - dest_filename = os.path.join(dest_folder, fn) - - if os.environ.get("DRY_RUN") == "1": - return - - if os.path.exists(dest_filename): - print(f"✖️ ️ Skipping previously downloaded mvanet/{fn}") - else: - os.makedirs(dest_folder, exist_ok=True) - print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n") - gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True) - print(f"{previous_line}✅ 
Downloaded mvanet/{fn} => '{rel(dest_filename)}' ") - - check_hash(dest_filename, "b915d492") - - -def download_box_segmenter(): - download_file( - "https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors", - dest_folder=test_weights_dir, - filename="finegrain-box-segmenter-v0-1.safetensors", - expected_hash="e0450e8c", - ) - - -def printg(msg: str): - """print in green color""" - print("\033[92m" + msg + "\033[0m") - - -def run_conversion_script( - script_filename: str, - from_weights: str, - to_weights: str, - half: bool = False, - expected_hash: str | None = None, - additional_args: list[str] | None = None, - skip_existing: bool = True, -): - if skip_existing and expected_hash and os.path.exists(to_weights): - found_hash = check_hash(to_weights, expected_hash) - if expected_hash == found_hash: - printg(f"✖️ Skipping converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ") - return - - msg = f"Converting {from_weights} to {to_weights}" - printg(msg) - - args = ["python", f"scripts/conversion/{script_filename}", "--from", from_weights, "--to", to_weights] - if half: - args.append("--half") - if additional_args: - args.extend(additional_args) - - subprocess.run(args, check=True) - if expected_hash is not None: - found_hash = check_hash(to_weights, expected_hash) - printg(f"✅ Converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ") - else: - printg(f"✅⚠️ Converted from {from_weights} to {to_weights} (no hash check performed)") - - -def convert_sd15(): - run_conversion_script( - script_filename="convert_transformers_clip_text_model.py", - from_weights="tests/weights/runwayml/stable-diffusion-v1-5", - to_weights="tests/weights/CLIPTextEncoderL.safetensors", - half=True, - expected_hash="6c9cbc59", - ) - run_conversion_script( - "convert_diffusers_autoencoder_kl.py", - "tests/weights/runwayml/stable-diffusion-v1-5", - "tests/weights/lda.safetensors", - expected_hash="329e369c", - ) - run_conversion_script( - "convert_diffusers_unet.py", - "tests/weights/runwayml/stable-diffusion-v1-5", - "tests/weights/unet.safetensors", - half=True, - expected_hash="f81ac65a", - ) - os.makedirs("tests/weights/inpainting", exist_ok=True) - run_conversion_script( - "convert_diffusers_unet.py", - "tests/weights/runwayml/stable-diffusion-inpainting", - "tests/weights/inpainting/unet.safetensors", - half=True, - expected_hash="c07a8c61", - ) - - -def convert_sdxl(): - run_conversion_script( - "convert_transformers_clip_text_model.py", - "tests/weights/stabilityai/stable-diffusion-xl-base-1.0", - "tests/weights/DoubleCLIPTextEncoder.safetensors", - half=True, - expected_hash="7f99c30b", - additional_args=["--subfolder2", "text_encoder_2"], - ) - run_conversion_script( - "convert_diffusers_autoencoder_kl.py", - "tests/weights/stabilityai/stable-diffusion-xl-base-1.0", - "tests/weights/sdxl-lda.safetensors", - half=True, - expected_hash="7464e9dc", - ) - run_conversion_script( - "convert_diffusers_unet.py", - "tests/weights/stabilityai/stable-diffusion-xl-base-1.0", - "tests/weights/sdxl-unet.safetensors", - half=True, - expected_hash="2e5c4911", - ) - - -def convert_vae_ft_mse(): - run_conversion_script( - "convert_diffusers_autoencoder_kl.py", - "tests/weights/stabilityai/sd-vae-ft-mse", - "tests/weights/lda_ft_mse.safetensors", - half=True, - expected_hash="4d0bae7e", - ) - - -def convert_vae_fp16_fix(): - run_conversion_script( - "convert_diffusers_autoencoder_kl.py", - "tests/weights/madebyollin/sdxl-vae-fp16-fix", - 
"tests/weights/sdxl-lda-fp16-fix.safetensors", - additional_args=["--subfolder", "''"], - half=True, - expected_hash="98c7e998", - ) - - -def convert_preprocessors(): - subprocess.run( - [ - "curl", - "-L", - "https://raw.githubusercontent.com/carolineec/informative-drawings/main/model.py", - "-o", - "src/model.py", - ], - check=True, - ) - run_conversion_script( - "convert_informative_drawings.py", - "tests/weights/carolineec/informativedrawings/model2.pth", - "tests/weights/informative-drawings.safetensors", - expected_hash="93dca207", - ) - os.remove("src/model.py") - - -def convert_controlnet(): - os.makedirs("tests/weights/controlnet", exist_ok=True) - run_conversion_script( - "convert_diffusers_controlnet.py", - "tests/weights/lllyasviel/control_v11p_sd15_canny", - "tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors", - expected_hash="9a1a48cf", - ) - run_conversion_script( - "convert_diffusers_controlnet.py", - "tests/weights/lllyasviel/control_v11f1p_sd15_depth", - "tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors", - expected_hash="bbe7e5a6", - ) - run_conversion_script( - "convert_diffusers_controlnet.py", - "tests/weights/lllyasviel/control_v11p_sd15_normalbae", - "tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors", - expected_hash="9fa88ed5", - ) - run_conversion_script( - "convert_diffusers_controlnet.py", - "tests/weights/lllyasviel/control_v11p_sd15_lineart", - "tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors", - expected_hash="c29e8c03", - ) - run_conversion_script( - "convert_diffusers_controlnet.py", - "tests/weights/mfidabel/controlnet-segment-anything", - "tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors", - expected_hash="d536eebb", - ) - run_conversion_script( - "convert_diffusers_controlnet.py", - "tests/weights/lllyasviel/control_v11f1e_sd15_tile", - "tests/weights/controlnet/lllyasviel_control_v11f1e_sd15_tile.safetensors", - expected_hash="42463af8", - ) - - -def convert_unclip(): - run_conversion_script( - "convert_transformers_clip_image_model.py", - "tests/weights/stabilityai/stable-diffusion-2-1-unclip", - "tests/weights/CLIPImageEncoderH.safetensors", - half=True, - expected_hash="4ddb44d2", - ) - - -def convert_ip_adapter(): - run_conversion_script( - "convert_diffusers_ip_adapter.py", - "tests/weights/h94/IP-Adapter/models/ip-adapter_sd15.bin", - "tests/weights/ip-adapter_sd15.safetensors", - expected_hash="3fb0472e", - ) - run_conversion_script( - "convert_diffusers_ip_adapter.py", - "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin", - "tests/weights/ip-adapter_sdxl_vit-h.safetensors", - half=True, - expected_hash="860518fe", - ) - run_conversion_script( - "convert_diffusers_ip_adapter.py", - "tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin", - "tests/weights/ip-adapter-plus_sd15.safetensors", - half=True, - expected_hash="aba8503b", - ) - run_conversion_script( - "convert_diffusers_ip_adapter.py", - "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin", - "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors", - half=True, - expected_hash="545d5ce7", - ) - - -def convert_ella_adapter(): - os.makedirs("tests/weights/ELLA-Adapter", exist_ok=True) - run_conversion_script( - "convert_ella_adapter.py", - "tests/weights/QQGYLab/ELLA/ella-sd1.5-tsc-t5xl.safetensors", - "tests/weights/ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors", - half=True, - expected_hash="b8244cb6", - ) - - -def 
convert_t2i_adapter(): - os.makedirs("tests/weights/T2I-Adapter", exist_ok=True) - run_conversion_script( - "convert_diffusers_t2i_adapter.py", - "tests/weights/TencentARC/t2iadapter_depth_sd15v2", - "tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors", - half=True, - expected_hash="bb2b3115", - ) - run_conversion_script( - "convert_diffusers_t2i_adapter.py", - "tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0", - "tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors", - half=True, - expected_hash="f07249a6", - ) - - -def convert_sam(): - run_conversion_script( - "convert_segment_anything.py", - "tests/weights/sam_vit_h_4b8939.pth", - "tests/weights/segment-anything-h.safetensors", - expected_hash="5ffb976f", - ) - - -def convert_hq_sam(): - run_conversion_script( - "convert_hq_segment_anything.py", - "tests/weights/sam_hq_vit_h.pth", - "tests/weights/refiners-sam-hq-vit-h.safetensors", - expected_hash="b2f5e79f", - ) - - -def convert_dinov2(): - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vits14_pretrain.pth", - "tests/weights/dinov2_vits14_pretrain.safetensors", - expected_hash="af000ded", - ) - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vitb14_pretrain.pth", - "tests/weights/dinov2_vitb14_pretrain.safetensors", - expected_hash="d6294087", - ) - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vitl14_pretrain.pth", - "tests/weights/dinov2_vitl14_pretrain.safetensors", - expected_hash="ddd4819f", - ) - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vitg14_pretrain.pth", - "tests/weights/dinov2_vitg14_pretrain.safetensors", - expected_hash="880c61f5", - ) - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vits14_reg4_pretrain.pth", - "tests/weights/dinov2_vits14_reg4_pretrain.safetensors", - expected_hash="080247c7", - ) - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vitb14_reg4_pretrain.pth", - "tests/weights/dinov2_vitb14_reg4_pretrain.safetensors", - expected_hash="5cd4d408", - ) - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vitl14_reg4_pretrain.pth", - "tests/weights/dinov2_vitl14_reg4_pretrain.safetensors", - expected_hash="b1221702", - ) - run_conversion_script( - "convert_dinov2.py", - "tests/weights/dinov2_vitg14_reg4_pretrain.pth", - "tests/weights/dinov2_vitg14_reg4_pretrain.safetensors", - expected_hash="639398eb", - ) - - -def convert_control_lora_fooocus(): - run_conversion_script( - "convert_fooocus_control_lora.py", - "tests/weights/lllyasviel/misc/control-lora-canny-rank128.safetensors", - "tests/weights/control-loras/refiners_control-lora-canny-rank128.safetensors", - expected_hash="4d505134", - ) - run_conversion_script( - "convert_fooocus_control_lora.py", - "tests/weights/lllyasviel/misc/fooocus_xl_cpds_128.safetensors", - "tests/weights/control-loras/refiners_fooocus_xl_cpds_128.safetensors", - expected_hash="d81aa461", - ) - - -def convert_lcm_base(): - run_conversion_script( - "convert_diffusers_unet.py", - "tests/weights/latent-consistency/lcm-sdxl", - "tests/weights/sdxl-lcm-unet.safetensors", - half=True, - expected_hash="e161b20c", - ) - - -def convert_sdxl_lightning_base(): - run_conversion_script( - "convert_diffusers_unet.py", - "tests/weights/stabilityai/stable-diffusion-xl-base-1.0", - "tests/weights/sdxl_lightning_4step_unet.safetensors", - additional_args=[ - "--override-weights", - "tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors", - 
], - half=True, - expected_hash="cfdc46da", - ) - - run_conversion_script( - "convert_diffusers_unet.py", - "tests/weights/stabilityai/stable-diffusion-xl-base-1.0", - "tests/weights/sdxl_lightning_1step_unet_x0.safetensors", - additional_args=[ - "--override-weights", - "tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_1step_unet_x0.safetensors", - ], - half=True, - expected_hash="21166a64", - ) - - -def convert_ic_light(): - run_conversion_script( - "convert_ic_light.py", - "tests/weights/iclight_sd15_fc.safetensors", - "tests/weights/iclight_sd15_fc-refiners.safetensors", - half=False, - expected_hash="be315c1f", - ) - - -def convert_mvanet(): - run_conversion_script( - "convert_mvanet.py", - "tests/weights/mvanet/Model_80.pth", - "tests/weights/mvanet/mvanet.safetensors", - half=True, - expected_hash="bf9ae4cb", - ) - - -def download_all(): - print(f"\nAll weights will be downloaded to {test_weights_dir}\n") - download_sd15("runwayml/stable-diffusion-v1-5") - download_sd15("runwayml/stable-diffusion-inpainting") - download_sdxl("stabilityai/stable-diffusion-xl-base-1.0") - download_vae_ft_mse() - download_vae_fp16_fix() - download_loras() - download_preprocessors() - download_controlnet() - download_unclip() - download_ip_adapter() - download_t2i_adapter() - download_ella_adapter() - download_sam() - download_hq_sam() - download_dinov2() - download_control_lora_fooocus() - download_lcm_base() - download_lcm_lora() - download_sdxl_lightning_base() - download_sdxl_lightning_lora() - download_ic_light() - download_mvanet() - download_box_segmenter() - - -def convert_all(): - convert_sd15() - convert_sdxl() - convert_vae_ft_mse() - convert_vae_fp16_fix() - # Note: no convert loras: this is done at runtime by `SDLoraManager` - convert_preprocessors() - convert_controlnet() - convert_unclip() - convert_ip_adapter() - convert_t2i_adapter() - convert_ella_adapter() - convert_sam() - convert_hq_sam() - convert_dinov2() - convert_control_lora_fooocus() - convert_lcm_base() - convert_sdxl_lightning_base() - convert_ic_light() - convert_mvanet() - - -def main(): - try: - download_all() - print(f"{download_count} files ({human_readable_size(bytes_count)})\n") - if not bool(os.environ.get("DRY_RUN") == "1"): - printg("Converting weights to refiners format\n") - convert_all() - except KeyboardInterrupt: - print("Stopped") - - -if __name__ == "__main__": - main() diff --git a/src/refiners/fluxion/model_converter.py b/src/refiners/fluxion/model_converter.py deleted file mode 100644 index 3b2f8b5..0000000 --- a/src/refiners/fluxion/model_converter.py +++ /dev/null @@ -1,654 +0,0 @@ -from collections import defaultdict -from enum import Enum, auto -from pathlib import Path -from typing import Any, DefaultDict, TypedDict - -import torch -from torch import Tensor, nn -from torch.utils.hooks import RemovableHandle - -from refiners.fluxion.utils import no_grad, norm, save_to_safetensors - -TORCH_BASIC_LAYERS: list[type[nn.Module]] = [ - nn.Conv1d, - nn.Conv2d, - nn.Conv3d, - nn.ConvTranspose1d, - nn.ConvTranspose2d, - nn.ConvTranspose3d, - nn.Linear, - nn.BatchNorm1d, - nn.BatchNorm2d, - nn.BatchNorm3d, - nn.LayerNorm, - nn.GroupNorm, - nn.Embedding, - nn.MaxPool2d, - nn.AvgPool2d, - nn.AdaptiveAvgPool2d, -] - - -ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]] - - -class ModuleArgsDict(TypedDict): - """Represents positional and keyword arguments passed to a module. - - - `positional`: A tuple of positional arguments. - - `keyword`: A dictionary of keyword arguments. 
- """ - - positional: tuple[Any, ...] - keyword: dict[str, Any] - - -class ConversionStage(Enum): - """Represents the current stage of the conversion process. - - Attributes: - INIT: The conversion process has not started. - BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers. - SHAPE_AND_LAYERS_MATCH: The shape of both models agree. - MODELS_OUTPUT_AGREE: The source and target models agree. - """ - - INIT = auto() - BASIC_LAYERS_MATCH = auto() - SHAPE_AND_LAYERS_MATCH = auto() - MODELS_OUTPUT_AGREE = auto() - - -class ModelConverter: - """Converts a model's state_dict to match another model's state_dict. - - Note: The conversion process consists of three stages - 1. Verify that the source and target models have the same number of basic layers. - 2. Find matching shapes and layers between the source and target models. - 3. Convert the source model's state_dict to match the target model's state_dict. - 4. Compare the outputs of the source and target models. - - The conversion process can be run multiple times, and will resume from the last stage. - - Example: - ```py - source = ... - target = ... - - converter = ModelConverter( - source_model=source, - target_model=target, - threshold=0.1, - verbose=False - ) - - is_converted = converter(args) - if is_converted: - converter.save_to_safetensors(path="converted_model.pt") - ``` - """ - - ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict - stage: ConversionStage = ConversionStage.INIT - _stored_mapping: dict[str, str] | None = None - - def __init__( - self, - source_model: nn.Module, - target_model: nn.Module, - source_keys_to_skip: list[str] | None = None, - target_keys_to_skip: list[str] | None = None, - custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None, - threshold: float = 1e-5, - skip_output_check: bool = False, - skip_init_check: bool = False, - verbose: bool = True, - ) -> None: - """Initializes the ModelConverter. - - Args: - source_model: The model to convert from. - target_model: The model to convert to. - source_keys_to_skip: A list of keys to skip when tracing the source model. - target_keys_to_skip: A list of keys to skip when tracing the target model. - custom_layer_mapping: A dictionary mapping custom layer types between the source and target models. - threshold: The threshold for comparing outputs between the source and target models. - skip_output_check: Whether to skip comparing the outputs of the source and target models. - skip_init_check: Whether to skip checking that the source and target models have the same number of basic - layers. - verbose: Whether to print messages during the conversion process. - - """ - self.source_model = source_model - self.target_model = target_model - self.source_keys_to_skip = source_keys_to_skip or [] - self.target_keys_to_skip = target_keys_to_skip or [] - self.custom_layer_mapping = custom_layer_mapping or {} - self.threshold = threshold - self.skip_output_check = skip_output_check - self.skip_init_check = skip_init_check - self.verbose = verbose - - def __repr__(self) -> str: - return ( - f"ModelConverter(source_model={self.source_model.__class__.__name__}," - f" target_model={self.target_model.__class__.__name__}, stage={self.stage})" - ) - - def __bool__(self) -> bool: - return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3 - - def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool: - """Run the conversion process. 
-
-    def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
-        """Run the conversion process.
-
-        Args:
-            source_args: The arguments to pass to the source model. It can be either a tuple of positional arguments,
-                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If
-                `target_args` is not provided, these arguments will also be passed to the target model.
-            target_args: The arguments to pass to the target model. It can be either a tuple of positional arguments,
-                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-
-        Returns:
-            True if the conversion process is done and the models agree.
-        """
-        if target_args is None:
-            target_args = source_args
-
-        match self.stage:
-            case ConversionStage.MODELS_OUTPUT_AGREE:
-                self._increment_stage()
-                return True
-
-            case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
-                source_args=source_args, target_args=target_args
-            ):
-                self._increment_stage()
-                return True
-
-            case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
-                source_args=source_args, target_args=target_args
-            ):
-                self._increment_stage()
-                return self.run(source_args=source_args, target_args=target_args)
-
-            case ConversionStage.INIT if self._run_init_stage():
-                self._increment_stage()
-                return self.run(source_args=source_args, target_args=target_args)
-
-            case _:
-                self._log(message=f"Conversion failed at stage {self.stage.value}")
-                return False
-
-    def _increment_stage(self) -> None:
-        """Increment the stage of the conversion process."""
-        match self.stage:
-            case ConversionStage.INIT:
-                self.stage = ConversionStage.BASIC_LAYERS_MATCH
-                self._log(
-                    message=(
-                        "Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
-                        " layers..."
-                    )
-                )
-            case ConversionStage.BASIC_LAYERS_MATCH:
-                self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
-                self._log(
-                    message=(
-                        "Stage 1 -> 2 - Shapes of both models agree. Applying state_dict to target model. Comparing"
-                        " models..."
-                    )
-                )
-
-            case ConversionStage.SHAPE_AND_LAYERS_MATCH:
-                if self.skip_output_check:
-                    self._log(
-                        message=(
-                            "Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
-                            " `skip_output_check` to `False`"
-                        )
-                    )
-                else:
-                    self.stage = ConversionStage.MODELS_OUTPUT_AGREE
-                    self._log(
-                        message=(
-                            "Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
-                            " converted model using `save_to_safetensors`"
-                        )
-                    )
-            case ConversionStage.MODELS_OUTPUT_AGREE:
-                self._log(
-                    message=(
-                        "Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
-                        " the converted model using `save_to_safetensors`"
-                    )
-                )
-
-    def get_state_dict(self) -> dict[str, Tensor]:
-        """Get the converted state_dict."""
-        if not self:
-            raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
-        return self.target_model.state_dict()
-
-    def get_mapping(self) -> dict[str, str]:
-        """Get the mapping between the source and target models' state_dicts."""
-        if not self:
-            raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
-        assert self._stored_mapping is not None, "Mapping is not stored"
-        return self._stored_mapping
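`run`, `map_state_dicts` and `compare_models` all accept the three `ModuleArgs` forms declared above. A short illustration with toy models and hypothetical values:

```py
# The three accepted ModuleArgs forms, shown with toy models.
import torch
from torch import nn
from refiners.fluxion.model_converter import ModelConverter

converter = ModelConverter(source_model=nn.Linear(4, 4), target_model=nn.Linear(4, 4))
x = torch.randn(2, 4)

converter.run(source_args=(x,))                                 # tuple of positional args
converter.run(source_args={"input": x})                         # dict of keyword args
converter.run(source_args={"positional": (x,), "keyword": {}})  # ModuleArgsDict
```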
-
-    def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
-        """Save the converted model to a SafeTensors file.
-
-        Warning:
-            This method can only be called after the conversion process is done.
-
-        Args:
-            path: The path to save the converted model to.
-            metadata: Metadata to save with the converted model.
-            half: Whether to save the converted model as half precision.
-
-        Raises:
-            ValueError: If the conversion process is not done yet. Run `converter(args)` first.
-        """
-        if not self:
-            raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
-        state_dict = self.get_state_dict()
-        if half:
-            state_dict = {key: value.half() for key, value in state_dict.items()}
-        save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)
-
-    def map_state_dicts(
-        self,
-        source_args: ModuleArgs,
-        target_args: ModuleArgs | None = None,
-    ) -> dict[str, str] | None:
-        """Find a mapping between the source and target models' state_dicts.
-
-        Args:
-            source_args: The arguments to pass to the source model. It can be either a tuple of positional arguments,
-                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If
-                `target_args` is not provided, these arguments will also be passed to the target model.
-            target_args: The arguments to pass to the target model. It can be either a tuple of positional arguments,
-                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-
-        Returns:
-            A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
-        """
-        if target_args is None:
-            target_args = source_args
-
-        source_order = self._trace_module_execution_order(
-            module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
-        )
-        target_order = self._trace_module_execution_order(
-            module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
-        )
-
-        if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
-            return None
-
-        mapping: dict[str, str] = {}
-        for source_type_shape in source_order:
-            source_keys = source_order[source_type_shape]
-            target_type_shape = source_type_shape
-            if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
-                for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
-                    if source_custom_type == source_type_shape[0]:
-                        target_type_shape = (target_custom_type, source_type_shape[1])
-                        break
-
-            target_keys = target_order[target_type_shape]
-            mapping.update(zip(target_keys, source_keys))
-
-        return mapping
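When a parameter falls outside the traced mapping, the conversion scripts above patch the result of `map_state_dicts` by hand before applying it (see the `pos_embed` handling in convert_segment_anything.py). A condensed sketch of that pattern; the mapping keys are hypothetical:

```py
from torch import Tensor, nn
from refiners.fluxion.model_converter import ModelConverter


def convert_with_manual_entry(
    converter: ModelConverter, source: nn.Module, target: nn.Module, x: Tensor
) -> dict[str, Tensor]:
    # Trace a mapping, add an entry the tracer cannot discover, then convert.
    mapping = converter.map_state_dicts(source_args=(x,))
    assert mapping is not None
    mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"  # hypothetical keys
    converted = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
        source_state_dict=source.state_dict(),
        target_state_dict=target.state_dict(),
        state_dict_mapping=mapping,
    )
    target.load_state_dict(state_dict=converted)
    return converted
```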
- """ - if target_args is None: - target_args = source_args - - source_outputs = self._collect_layers_outputs( - module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip - ) - target_outputs = self._collect_layers_outputs( - module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip - ) - - diff, prev_source_key, prev_target_key = None, None, None - for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs): - diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item() - if diff > threshold: - self._log( - f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and" - f" {target_key}, difference in norm: {diff}" - ) - return False - prev_source_key, prev_target_key = source_key, target_key - - self._log(message=f"Models agree. Difference in norm: {diff}") - - return True - - def _run_init_stage(self) -> bool: - """Run the init stage of the conversion process.""" - if self.skip_init_check: - self._log( - message=( - "Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to" - " `False`" - ) - ) - return True - - is_count_correct = self._verify_basic_layers_count() - is_not_missing_layers = self._verify_missing_basic_layers() - - return is_count_correct and is_not_missing_layers - - def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool: - """Run the basic layers match stage of the conversion process.""" - mapping = self.map_state_dicts(source_args=source_args, target_args=target_args) - self._stored_mapping = mapping - if mapping is None: - self._log(message="Models do not have matching shapes.") - return False - - source_state_dict = self.source_model.state_dict() - target_state_dict = self.target_model.state_dict() - converted_state_dict = self._convert_state_dict( - source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping - ) - self.target_model.load_state_dict(state_dict=converted_state_dict) - - return True - - def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool: - """Run the shape and layers match stage of the conversion process.""" - if self.skip_output_check: - self._log( - message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`" - ) - return True - - try: - if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold): - self._log(message="Models agree. You can export the converted model using `save_to_safetensors`") - return True - else: - self._log(message="Models do not agree. 
-    def _log(self, message: str) -> None:
-        """Print a message if `verbose` is `True`."""
-        if self.verbose:
-            print(message)
-
-    def _debug_print_shapes(
-        self,
-        shape: ModelTypeShape,
-        source_keys: list[str],
-        target_keys: list[str],
-    ) -> None:
-        """Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
-        self._log(message=f"{shape}")
-        max_len = max(len(source_keys), len(target_keys))
-        for i in range(max_len):
-            source_key = source_keys[i] if i < len(source_keys) else "---"
-            target_key = target_keys[i] if i < len(target_keys) else "---"
-            self._log(f"\t{source_key}\t{target_key}")
-
-    @staticmethod
-    def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
-        """Unpack the positional and keyword arguments passed to a module."""
-        match module_args:
-            case tuple(positional_args):
-                keyword_args: dict[str, Any] = {}
-            case {"positional": positional_args, "keyword": keyword_args}:
-                pass
-            case _:
-                positional_args = ()
-                keyword_args = dict(**module_args)
-
-        return positional_args, keyword_args
-
-    def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
-        """Check if a module type is a subclass of a torch basic layer."""
-        return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)
-
-    def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
-        """Infer the type of a basic layer."""
-        layer_types = (
-            set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
-        )
-        for layer_type in layer_types:
-            if isinstance(module, layer_type):
-                return layer_type
-
-        return None
-
-    def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
-        """Get the signature of a module."""
-        layer_type = self._infer_basic_layer_type(module=module)
-        assert layer_type is not None, f"Module {module} is not a basic layer"
-        param_shapes = [p.shape for p in module.parameters()]
-        return (layer_type, tuple(param_shapes))
-
-    def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
-        """Count the number of basic layers in a module."""
-        count: DefaultDict[type[nn.Module], int] = defaultdict(int)
-        for submodule in module.modules():
-            layer_type = self._infer_basic_layer_type(module=submodule)
-            if layer_type is not None:
-                count[layer_type] += 1
-
-        return count
-
-    def _verify_basic_layers_count(self) -> bool:
-        """Verify that the source and target models have the same number of basic layers."""
-        source_layers = self._count_basic_layers(module=self.source_model)
-        target_layers = self._count_basic_layers(module=self.target_model)
-
-        reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}
-
-        diff: dict[type[nn.Module], tuple[int, int]] = {}
-        for layer_type, source_count in source_layers.items():
-            target_type = self.custom_layer_mapping.get(layer_type, layer_type)
-            target_count = target_layers.get(target_type, 0)
-
-            if source_count != target_count:
-                diff[layer_type] = (source_count, target_count)
-
-        for layer_type, target_count in target_layers.items():
-            source_type = reverse_mapping.get(layer_type, layer_type)
-            source_count = source_layers.get(source_type, 0)
-
-            if source_count != target_count:
-                diff[layer_type] = (source_count, target_count)
-
-        if diff:
-            message = "Models do not have the same number of basic layers:\n"
-            for layer_type, counts in diff.items():
-                message += f"  {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
-            self._log(message=message.strip())
-            return False
-
-        return True
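As a reading aid for `_unpack_module_args` above, these are the three `ModuleArgs` forms it accepts; `x` and `y` are placeholder tensors.

    # Sketch of the three accepted argument forms:
    import torch

    x, y = torch.randn(2, 8), torch.randn(2, 8)  # placeholder tensors
    args_as_tuple = (x,)                                       # positional arguments
    args_as_kwargs = {"x": x}                                  # keyword arguments
    args_explicit = {"positional": (x,), "keyword": {"y": y}}  # both, explicitly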
-    def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
-        """Check if a module is a leaf module with weights."""
-        return next(module.parameters(), None) is not None and next(module.children(), None) is None
-
-    def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
-        """Check if a module has weighted leaf modules that are not basic layers."""
-        return [
-            type(submodule)
-            for submodule in module.modules()
-            if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
-        ]
-
-    def _verify_missing_basic_layers(self) -> bool:
-        """Verify that the source and target models do not have missing basic layers."""
-        missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
-        missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)
-
-        if missing_source_layers or missing_target_layers:
-            self._log(
-                message=(
-                    "Models might have missing basic layers. If you want to skip this check, set"
-                    f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
-                )
-            )
-            return False
-
-        return True
-
-    @no_grad()
-    def _trace_module_execution_order(
-        self,
-        module: nn.Module,
-        args: ModuleArgs,
-        keys_to_skip: list[str],
-    ) -> dict[ModelTypeShape, list[str]]:
-        """Execute a forward pass and store the order of execution of specific sub-modules.
-
-        Args:
-            module: The module to trace.
-            args: The arguments to pass to the module. It can be either a tuple of positional arguments,
-                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-            keys_to_skip: A list of keys to skip when tracing the module.
-
-        Returns:
-            A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`.
-        """
-        submodule_to_key: dict[nn.Module, str] = {}
-        execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
-
-        def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
-            layer_signature = self.get_module_signature(module=layer)
-            execution_order[layer_signature].append(submodule_to_key[layer])
-
-        hooks: list[RemovableHandle] = []
-        named_modules: list[tuple[str, nn.Module]] = module.named_modules()  # type: ignore
-        for name, submodule in named_modules:
-            if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
-                submodule_to_key[submodule] = name  # type: ignore
-                hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
-                hooks.append(hook)
-
-        positional_args, keyword_args = self._unpack_module_args(module_args=args)
-        module(*positional_args, **keyword_args)
-
-        for hook in hooks:
-            hook.remove()
-
-        return dict(execution_order)
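The tracing above is plain PyTorch forward hooks. A self-contained sketch of the same idea, reduced to recording execution order on a toy model:

    # Record the names of Linear layers in the order they execute.
    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    name_of = {m: n for n, m in model.named_modules()}
    order: list[str] = []

    hooks = [
        m.register_forward_hook(lambda layer, *_: order.append(name_of[layer]))
        for m in model.modules()
        if isinstance(m, nn.Linear)
    ]
    model(torch.randn(1, 4))
    for hook in hooks:
        hook.remove()
    # order == ["0", "2"]: the two Linear layers, in execution order.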
-    def _assert_shapes_aligned(
-        self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
-    ) -> bool:
-        """Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
-        model_type_shapes = set(source_order.keys()) | set(target_order.keys())
-
-        default_type_shapes = [
-            type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
-        ]
-
-        shape_mismatched = False
-
-        for model_type_shape in default_type_shapes:
-            source_keys = source_order.get(model_type_shape, [])
-            target_keys = target_order.get(model_type_shape, [])
-
-            if len(source_keys) != len(target_keys):
-                shape_mismatched = True
-                self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
-
-        for source_custom_type in self.custom_layer_mapping.keys():
-            # Iterate over all type_shapes that have the same type as source_custom_type.
-            for source_type_shape in [
-                type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
-            ]:
-                source_keys = source_order.get(source_type_shape, [])
-                target_custom_type = self.custom_layer_mapping[source_custom_type]
-                target_type_shape = (target_custom_type, source_type_shape[1])
-                target_keys = target_order.get(target_type_shape, [])
-
-                if len(source_keys) != len(target_keys):
-                    shape_mismatched = True
-                    self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)
-
-        return not shape_mismatched
-
-    @staticmethod
-    def _convert_state_dict(
-        source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
-    ) -> dict[str, Tensor]:
-        """Convert the source model's state_dict to match the target model's state_dict."""
-        converted_state_dict: dict[str, Tensor] = {}
-        for target_key in target_state_dict:
-            target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
-            source_prefix = state_dict_mapping[target_prefix]
-            source_key = ".".join([source_prefix, suffix])
-            converted_state_dict[target_key] = source_state_dict[source_key]
-
-        return converted_state_dict
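To make the key remapping in `_convert_state_dict` concrete, a toy example with invented key names:

    # The mapping goes from target key prefixes to source key prefixes.
    state_dict_mapping = {"encoder.block1": "down_blocks.0"}
    # The target key "encoder.block1.weight" splits into
    # ("encoder.block1", "weight"); after the prefix lookup, the value is
    # read from the source key "down_blocks.0.weight".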
-    @no_grad()
-    def _collect_layers_outputs(
-        self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
-    ) -> list[tuple[str, Tensor]]:
-        """Execute a forward pass and store the output of specific sub-modules.
-
-        Args:
-            module: The module to trace.
-            args: The arguments to pass to the module. It can be either a tuple of positional arguments,
-                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-            keys_to_skip: A list of keys to skip when tracing the module.
-
-        Returns:
-            A list of tuples containing the key of each sub-module and its output.
-
-        Note:
-            The output of each sub-module is cloned to avoid memory leaks.
-        """
-        submodule_to_key: dict[nn.Module, str] = {}
-        execution_order: list[tuple[str, Tensor]] = []
-
-        def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None:
-            execution_order.append((submodule_to_key[layer], output.clone()))
-
-        hooks: list[RemovableHandle] = []
-        named_modules: list[tuple[str, nn.Module]] = module.named_modules()  # type: ignore
-        for name, submodule in named_modules:
-            if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
-                submodule_to_key[submodule] = name  # type: ignore
-                hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
-                hooks.append(hook)
-
-        positional_args, keyword_args = self._unpack_module_args(module_args=args)
-        module(*positional_args, **keyword_args)
-
-        for hook in hooks:
-            hook.remove()
-
-        return execution_order
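Finally, when the two implementations express an equivalent layer as different classes, `custom_layer_mapping` bridges them during tracing, counting, and mapping. The layer classes below are hypothetical stand-ins, and passing the mapping at construction is an assumption based on how this module reads `self.custom_layer_mapping`:

    # Sketch: treat SourceAttention/TargetAttention (hypothetical classes) as
    # matching "basic layers" during tracing and state_dict mapping.
    converter = ModelConverter(
        source_model=source,
        target_model=target,
        custom_layer_mapping={SourceAttention: TargetAttention},  # hypothetical types
        verbose=True,
    )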