mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
implement the ConvertModule class and refactor conversion scripts
This commit is contained in:
parent
3680f9d196
commit
7ca6bd0ccd
12
README.md
12
README.md
|
@ -212,9 +212,9 @@ Here is how to perform a text-to-image inference using the Stable Diffusion 1.5
|
||||||
Step 1: prepare the model weights in refiners' format:
|
Step 1: prepare the model weights in refiners' format:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/convert-clip-weights.py --output-file CLIPTextEncoderL.safetensors
|
python scripts/conversion/convert_transformers_clip_text_model.py --to clip.safetensors
|
||||||
python scripts/convert-sd-lda-weights.py --output-file lda.safetensors
|
python scripts/conversion/convert_diffusers_autoencoder_kl.py --to lda.safetensors
|
||||||
python scripts/convert-sd-unet-weights.py --output-file unet.safetensors
|
python scripts/conversion/convert_diffusers_unet.py --to unet.safetensors
|
||||||
```
|
```
|
||||||
|
|
||||||
> Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5 which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stable-diffusion-v1-5` option instead.
|
> Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5 which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stable-diffusion-v1-5` option instead.
|
||||||
|
@ -223,9 +223,9 @@ Step 2: download and convert a community Pokemon LoRA, e.g. [this one](https://h
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
|
curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
|
||||||
python scripts/convert-lora-weights.py \
|
python scripts/conversion/convert_diffusers_lora.py \
|
||||||
--from pytorch_lora_weights.bin \
|
--from pytorch_lora_weights.bin \
|
||||||
--output-file pokemon_lora.safetensors
|
--to pokemon_lora.safetensors
|
||||||
```
|
```
|
||||||
|
|
||||||
Step 3: run inference using the GPU:
|
Step 3: run inference using the GPU:
|
||||||
|
@ -238,7 +238,7 @@ import torch
|
||||||
|
|
||||||
|
|
||||||
sd15 = StableDiffusion_1(device="cuda")
|
sd15 = StableDiffusion_1(device="cuda")
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("CLIPTextEncoderL.safetensors"))
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors"))
|
||||||
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
|
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
|
||||||
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
|
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
|
||||||
|
|
||||||
|
|
67
scripts/conversion/convert_diffusers_autoencoder_kl.py
Normal file
67
scripts/conversion/convert_diffusers_autoencoder_kl.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from diffusers import AutoencoderKL # type: ignore
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
|
||||||
|
source_path: str
|
||||||
|
output_path: str | None
|
||||||
|
use_half: bool
|
||||||
|
verbose: bool
|
||||||
|
|
||||||
|
|
||||||
|
def setup_converter(args: Args) -> ModelConverter:
|
||||||
|
target = LatentDiffusionAutoencoder()
|
||||||
|
source: nn.Module = AutoencoderKL.from_pretrained(pretrained_model_name_or_path=args.source_path, subfolder="vae") # type: ignore
|
||||||
|
x = torch.randn(1, 3, 512, 512)
|
||||||
|
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||||
|
if not converter.run(source_args=(x,)):
|
||||||
|
raise RuntimeError("Model conversion failed")
|
||||||
|
return converter
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert a pretrained diffusers AutoencoderKL model to a refiners Latent Diffusion Autoencoder"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source_path",
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help="Path to the source pretrained model (default: 'runwayml/stable-diffusion-v1-5').",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--to",
|
||||||
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
|
||||||
|
" be the source path with the extension changed to .safetensors."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--half",
|
||||||
|
action="store_true",
|
||||||
|
dest="use_half",
|
||||||
|
default=False,
|
||||||
|
help="Use this flag to save the output file as half precision (default: full precision).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
dest="verbose",
|
||||||
|
default=False,
|
||||||
|
help="Use this flag to print verbose output during conversion.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args(namespace=Args())
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.source_path).stem}-autoencoder.safetensors"
|
||||||
|
assert args.output_path is not None
|
||||||
|
converter = setup_converter(args=args)
|
||||||
|
converter.save_to_safetensors(path=args.output_path, half=args.use_half)
|
|
@ -1,18 +1,24 @@
|
||||||
|
# pyright: reportPrivateUsage=false
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from diffusers import ControlNetModel # type: ignore
|
from diffusers import ControlNetModel # type: ignore
|
||||||
from refiners.fluxion.utils import (
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
forward_order_of_execution,
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
verify_shape_match,
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Controlnet
|
||||||
convert_state_dict,
|
|
||||||
save_to_safetensors,
|
|
||||||
)
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
|
||||||
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
|
||||||
|
source_path: str
|
||||||
|
output_path: str | None
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
def convert(args: Args) -> dict[str, torch.Tensor]:
|
||||||
|
controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
|
||||||
controlnet = SD1Controlnet(name="mycn")
|
controlnet = SD1Controlnet(name="mycn")
|
||||||
|
|
||||||
condition = torch.randn(1, 3, 512, 512)
|
condition = torch.randn(1, 3, 512, 512)
|
||||||
|
@ -33,10 +39,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
# to diffusers in order, since we compute the residuals inline instead of
|
# to diffusers in order, since we compute the residuals inline instead of
|
||||||
# in a separate step.
|
# in a separate step.
|
||||||
|
|
||||||
source_order = forward_order_of_execution(module=controlnet_src, example_args=(x, timestep, clip_text_embedding, condition)) # type: ignore
|
converter = ModelConverter(
|
||||||
target_order = forward_order_of_execution(module=controlnet, example_args=(x,))
|
source_model=controlnet_src, target_model=controlnet, skip_output_check=True, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320])))
|
source_order = converter._trace_module_execution_order(
|
||||||
|
module=controlnet_src, args=(x, timestep, clip_text_embedding, condition), keys_to_skip=[]
|
||||||
|
)
|
||||||
|
target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[])
|
||||||
|
|
||||||
|
broken_k = (str(object=nn.Conv2d), (torch.Size([320, 320, 1, 1]), torch.Size([320])))
|
||||||
|
|
||||||
expected_source_order = [
|
expected_source_order = [
|
||||||
"down_blocks.0.attentions.0.proj_in",
|
"down_blocks.0.attentions.0.proj_in",
|
||||||
|
@ -75,7 +87,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
assert target_order[broken_k] == expected_target_order
|
assert target_order[broken_k] == expected_target_order
|
||||||
source_order[broken_k] = fixed_source_order
|
source_order[broken_k] = fixed_source_order
|
||||||
|
|
||||||
broken_k = ("Conv2d", (torch.Size([640, 640, 1, 1]), torch.Size([640])))
|
broken_k = (str(object=nn.Conv2d), (torch.Size([640, 640, 1, 1]), torch.Size([640])))
|
||||||
|
|
||||||
expected_source_order = [
|
expected_source_order = [
|
||||||
"down_blocks.1.attentions.0.proj_in",
|
"down_blocks.1.attentions.0.proj_in",
|
||||||
|
@ -111,7 +123,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
assert target_order[broken_k] == expected_target_order
|
assert target_order[broken_k] == expected_target_order
|
||||||
source_order[broken_k] = fixed_source_order
|
source_order[broken_k] = fixed_source_order
|
||||||
|
|
||||||
broken_k = ("Conv2d", (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
|
broken_k = (str(object=nn.Conv2d), (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
|
||||||
|
|
||||||
expected_source_order = [
|
expected_source_order = [
|
||||||
"down_blocks.2.attentions.0.proj_in",
|
"down_blocks.2.attentions.0.proj_in",
|
||||||
|
@ -162,7 +174,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
assert target_order[broken_k] == expected_target_order
|
assert target_order[broken_k] == expected_target_order
|
||||||
source_order[broken_k] = fixed_source_order
|
source_order[broken_k] = fixed_source_order
|
||||||
|
|
||||||
assert verify_shape_match(source_order=source_order, target_order=target_order)
|
assert converter._assert_shapes_aligned(source_order=source_order, target_order=target_order), "Shapes not aligned"
|
||||||
|
|
||||||
mapping: dict[str, str] = {}
|
mapping: dict[str, str] = {}
|
||||||
for model_type_shape in source_order:
|
for model_type_shape in source_order:
|
||||||
|
@ -170,7 +182,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
target_keys = target_order[model_type_shape]
|
target_keys = target_order[model_type_shape]
|
||||||
mapping.update(zip(target_keys, source_keys))
|
mapping.update(zip(target_keys, source_keys))
|
||||||
|
|
||||||
state_dict = convert_state_dict(
|
state_dict = converter._convert_state_dict(
|
||||||
source_state_dict=controlnet_src.state_dict(),
|
source_state_dict=controlnet_src.state_dict(),
|
||||||
target_state_dict=controlnet.state_dict(),
|
target_state_dict=controlnet.state_dict(),
|
||||||
state_dict_mapping=mapping,
|
state_dict_mapping=mapping,
|
||||||
|
@ -180,27 +192,33 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
import argparse
|
parser = argparse.ArgumentParser(description="Convert a diffusers ControlNet model to a Refiners ControlNet model")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--from",
|
"--from",
|
||||||
type=str,
|
type=str,
|
||||||
dest="source",
|
dest="source_path",
|
||||||
required=True,
|
default="lllyasviel/sd-controlnet-depth",
|
||||||
help="Source model",
|
help=(
|
||||||
|
"Can be a path to a .bin, a .safetensors file, or a model identifier from Hugging Face Hub. Defaults to"
|
||||||
|
" lllyasviel/sd-controlnet-depth"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-file",
|
"--to",
|
||||||
type=str,
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
required=False,
|
required=False,
|
||||||
default="output.safetensors",
|
default=None,
|
||||||
help="Path for the output file",
|
help=(
|
||||||
|
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
|
||||||
|
" source path."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args(namespace=Args())
|
||||||
controlnet_src = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source) # type: ignore
|
if args.output_path is None:
|
||||||
tensors = convert(controlnet_src=controlnet_src) # type: ignore
|
args.output_path = f"{Path(args.source_path).stem}-controlnet.safetensors"
|
||||||
save_to_safetensors(path=args.output_file, tensors=tensors)
|
state_dict = convert(args=args)
|
||||||
|
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
|
@ -1,20 +1,17 @@
|
||||||
# Note: this conversion script currently only support simple LoRAs which adapt
|
import argparse
|
||||||
# the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora
|
from pathlib import Path
|
||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
from torch.nn.init import zeros_
|
from torch.nn.init import zeros_
|
||||||
from torch.nn import Parameter as TorchParameter
|
from torch.nn import Parameter as TorchParameter
|
||||||
|
from diffusers import DiffusionPipeline # type: ignore
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
from refiners.fluxion.utils import save_to_safetensors
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
|
||||||
from refiners.adapters.lora import Lora
|
from refiners.adapters.lora import Lora
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def get_weight(linear: fl.Linear) -> torch.Tensor:
|
def get_weight(linear: fl.Linear) -> torch.Tensor:
|
||||||
|
@ -31,10 +28,17 @@ def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torc
|
||||||
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
|
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
|
||||||
|
source_path: str
|
||||||
|
base_model: str
|
||||||
|
output_file: str
|
||||||
|
verbose: bool
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def process(source: str, base_model: str, output_file: str) -> None:
|
def process(args: Args) -> None:
|
||||||
diffusers_state_dict = torch.load(source, map_location="cpu") # type: ignore
|
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
|
||||||
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model) # type: ignore
|
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model) # type: ignore
|
||||||
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
||||||
|
|
||||||
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
|
@ -54,10 +58,16 @@ def process(source: str, base_model: str, output_file: str) -> None:
|
||||||
|
|
||||||
diffusers_args = (x, timestep, clip_text_embeddings)
|
diffusers_args = (x, timestep, clip_text_embeddings)
|
||||||
|
|
||||||
diffusers_to_refiners = create_state_dict_mapping(
|
converter = ModelConverter(
|
||||||
source_model=refiners_model, target_model=diffusers_model, source_args=refiners_args, target_args=diffusers_args
|
source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose
|
||||||
)
|
)
|
||||||
assert diffusers_to_refiners is not None, "Model conversion failed"
|
if not converter.run(
|
||||||
|
source_args=refiners_args,
|
||||||
|
target_args=diffusers_args,
|
||||||
|
):
|
||||||
|
raise RuntimeError("Model conversion failed")
|
||||||
|
|
||||||
|
diffusers_to_refiners = converter.get_mapping()
|
||||||
|
|
||||||
apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
|
apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
|
||||||
for layer in refiners_model.layers(layer_type=Lora):
|
for layer in refiners_model.layers(layer_type=Lora):
|
||||||
|
@ -83,36 +93,47 @@ def process(source: str, base_model: str, output_file: str) -> None:
|
||||||
|
|
||||||
state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
|
state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
|
||||||
assert len(state_dict) == 320
|
assert len(state_dict) == 320
|
||||||
save_to_safetensors(path=output_file, tensors=state_dict, metadata=metadata)
|
save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
import argparse
|
parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--from",
|
"--from",
|
||||||
type=str,
|
type=str,
|
||||||
dest="source",
|
dest="source_path",
|
||||||
required=True,
|
required=True,
|
||||||
help="Source file path (.bin)",
|
help="Source file path (.bin|safetensors) containing the LoRAs.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base-model",
|
"--base-model",
|
||||||
type=str,
|
type=str,
|
||||||
required=False,
|
required=False,
|
||||||
default="runwayml/stable-diffusion-v1-5",
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
help="Base model",
|
help="Base model, used for the UNet structure. Default: runwayml/stable-diffusion-v1-5",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-file",
|
"--to",
|
||||||
type=str,
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
required=False,
|
required=False,
|
||||||
default="output.safetensors",
|
default=None,
|
||||||
help="Path for the output file",
|
help=(
|
||||||
|
"Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as"
|
||||||
|
" the source path."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
parser.add_argument(
|
||||||
process(source=args.source, base_model=args.base_model, output_file=args.output_file)
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
dest="verbose",
|
||||||
|
default=False,
|
||||||
|
help="Use this flag to print verbose output during conversion.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args(namespace=Args())
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
|
||||||
|
process(args=args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
96
scripts/conversion/convert_diffusers_unet.py
Normal file
96
scripts/conversion/convert_diffusers_unet.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
|
from diffusers import UNet2DConditionModel # type: ignore
|
||||||
|
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
|
||||||
|
source_path: str
|
||||||
|
output_path: str | None
|
||||||
|
half: bool
|
||||||
|
verbose: bool
|
||||||
|
|
||||||
|
|
||||||
|
def setup_converter(args: Args) -> ModelConverter:
|
||||||
|
source: nn.Module = UNet2DConditionModel.from_pretrained( # type: ignore
|
||||||
|
pretrained_model_name_or_path=args.source_path, subfolder="unet"
|
||||||
|
)
|
||||||
|
source_in_channels: int = source.config.in_channels # type: ignore
|
||||||
|
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||||
|
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
|
||||||
|
target = (
|
||||||
|
SDXLUNet(in_channels=source_in_channels)
|
||||||
|
if source_has_time_ids
|
||||||
|
else SD1UNet(in_channels=source_in_channels, clip_embedding_dim=source_clip_embedding_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.randn(1, source_in_channels, 32, 32)
|
||||||
|
timestep = torch.tensor(data=[0])
|
||||||
|
clip_text_embeddings = torch.randn(1, 77, source_clip_embedding_dim)
|
||||||
|
|
||||||
|
target.set_timestep(timestep=timestep)
|
||||||
|
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
||||||
|
added_cond_kwargs = {}
|
||||||
|
if source_has_time_ids:
|
||||||
|
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
||||||
|
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
||||||
|
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
||||||
|
|
||||||
|
target_args = (x,)
|
||||||
|
source_args = {
|
||||||
|
"positional": (x, timestep, clip_text_embeddings),
|
||||||
|
"keyword": {"added_cond_kwargs": added_cond_kwargs} if source_has_time_ids else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||||
|
if not converter.run(
|
||||||
|
source_args=source_args,
|
||||||
|
target_args=target_args,
|
||||||
|
):
|
||||||
|
raise RuntimeError("Model conversion failed")
|
||||||
|
return converter
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Converts a Diffusion UNet model to a Refiners SD1UNet or SDXLUNet model"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source_path",
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help=(
|
||||||
|
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
|
||||||
|
" runwayml/stable-diffusion-v1-5"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--to",
|
||||||
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
|
||||||
|
" source path."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True")
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Prints additional information during conversion. Default: False",
|
||||||
|
)
|
||||||
|
args = parser.parse_args(namespace=Args())
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.source_path).stem}-unet.safetensors"
|
||||||
|
converter = setup_converter(args=args)
|
||||||
|
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
64
scripts/conversion/convert_informative_drawings.py
Normal file
64
scripts/conversion/convert_informative_drawings.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import argparse
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
|
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||||
|
|
||||||
|
try:
|
||||||
|
from model import Generator # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
|
||||||
|
" PYTHONPATH"
|
||||||
|
)
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
Generator = cast(nn.Module, Generator)
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
|
||||||
|
source_path: str
|
||||||
|
output_path: str
|
||||||
|
verbose: bool
|
||||||
|
half: bool
|
||||||
|
|
||||||
|
|
||||||
|
def setup_converter(args: Args) -> ModelConverter:
|
||||||
|
source = Generator(3, 1, 3)
|
||||||
|
source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu")) # type: ignore
|
||||||
|
source.eval()
|
||||||
|
target = InformativeDrawings()
|
||||||
|
x = torch.randn(1, 3, 512, 512)
|
||||||
|
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||||
|
if not converter.run(source_args=(x,)):
|
||||||
|
raise RuntimeError("Model conversion failed")
|
||||||
|
return converter
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Converts a pretrained Informative Drawings model to a refiners Informative Drawings model"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source_path",
|
||||||
|
default="model2.pth",
|
||||||
|
help="Path to the source model. (default: 'model2.pth').",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--to",
|
||||||
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
|
default="informative-drawings.safetensors",
|
||||||
|
help="Path to save the converted model. (default: 'informative-drawings.safetensors').",
|
||||||
|
)
|
||||||
|
parser.add_argument("--verbose", action="store_true", dest="verbose")
|
||||||
|
parser.add_argument("--half", action="store_true", dest="half")
|
||||||
|
args = parser.parse_args(namespace=Args())
|
||||||
|
converter = setup_converter(args=args)
|
||||||
|
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,49 +1,36 @@
|
||||||
|
import argparse
|
||||||
|
from functools import partial
|
||||||
|
from torch import Tensor
|
||||||
from refiners.fluxion.utils import (
|
from refiners.fluxion.utils import (
|
||||||
load_from_safetensors,
|
load_from_safetensors,
|
||||||
load_metadata_from_safetensors,
|
load_metadata_from_safetensors,
|
||||||
save_to_safetensors,
|
save_to_safetensors,
|
||||||
)
|
)
|
||||||
|
from convert_diffusers_unet import setup_converter as convert_unet, Args as UnetConversionArgs
|
||||||
|
from convert_transformers_clip_text_model import (
|
||||||
|
setup_converter as convert_text_encoder,
|
||||||
|
Args as TextEncoderConversionArgs,
|
||||||
|
)
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||||
from refiners.fluxion.layers.module import Module
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
|
||||||
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
def get_unet_mapping(source_path: str) -> dict[str, str]:
|
||||||
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: SD1UNet) -> dict[str, str] | None:
|
args = UnetConversionArgs(source_path=source_path, verbose=False)
|
||||||
x = torch.randn(1, 4, 32, 32)
|
return convert_unet(args=args).get_mapping()
|
||||||
timestep = torch.tensor(data=[0])
|
|
||||||
clip_text_embeddings = torch.randn(1, 77, 768)
|
|
||||||
|
|
||||||
src_args = (x, timestep, clip_text_embeddings)
|
|
||||||
dst_model.set_timestep(timestep=timestep)
|
|
||||||
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
|
||||||
dst_args = (x,)
|
|
||||||
|
|
||||||
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
def get_text_encoder_mapping(source_path: str) -> dict[str, str]:
|
||||||
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
|
args = TextEncoderConversionArgs(source_path=source_path, subfolder="text_encoder", verbose=False)
|
||||||
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
return convert_text_encoder(
|
||||||
assert tokenizer is not None, "Could not find tokenizer"
|
args=args,
|
||||||
tokens = tokenizer("Nice cat")
|
).get_mapping()
|
||||||
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
import argparse
|
parser = argparse.ArgumentParser(description="Converts a refiner's LoRA weights to SD-WebUI's LoRA weights")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-i",
|
"-i",
|
||||||
"--input-file",
|
"--input-file",
|
||||||
|
@ -55,7 +42,7 @@ def main() -> None:
|
||||||
"-o",
|
"-o",
|
||||||
"--output-file",
|
"--output-file",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
default="sdwebui_loras.safetensors",
|
||||||
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
|
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -66,27 +53,22 @@ def main() -> None:
|
||||||
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
|
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
metadata = load_metadata_from_safetensors(path=args.input_file)
|
metadata = load_metadata_from_safetensors(path=args.input_file)
|
||||||
assert metadata is not None
|
assert metadata is not None, f"Could not load metadata from {args.input_file}"
|
||||||
tensors = load_from_safetensors(path=args.input_file)
|
tensors = load_from_safetensors(path=args.input_file)
|
||||||
|
|
||||||
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.sd15) # type: ignore
|
state_dict: dict[str, Tensor] = {}
|
||||||
|
|
||||||
state_dict: dict[str, torch.Tensor] = {}
|
|
||||||
|
|
||||||
for meta_key, meta_value in metadata.items():
|
for meta_key, meta_value in metadata.items():
|
||||||
match meta_key:
|
match meta_key:
|
||||||
case "unet_targets":
|
case "unet_targets":
|
||||||
src_model = diffusers_sd.unet # type: ignore
|
model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
create_mapping = partial(get_unet_mapping, source_path=args.sd15)
|
||||||
create_mapping = create_unet_mapping
|
|
||||||
key_prefix = "unet."
|
key_prefix = "unet."
|
||||||
lora_prefix = "lora_unet_"
|
lora_prefix = "lora_unet_"
|
||||||
case "text_encoder_targets":
|
case "text_encoder_targets":
|
||||||
src_model = diffusers_sd.text_encoder # type: ignore
|
model = CLIPTextEncoderL()
|
||||||
dst_model = CLIPTextEncoderL()
|
create_mapping = partial(get_text_encoder_mapping, source_path=args.sd15)
|
||||||
create_mapping = create_text_encoder_mapping
|
|
||||||
key_prefix = "text_encoder."
|
key_prefix = "text_encoder."
|
||||||
lora_prefix = "lora_te_"
|
lora_prefix = "lora_te_"
|
||||||
case "lda_targets":
|
case "lda_targets":
|
||||||
|
@ -94,8 +76,8 @@ def main() -> None:
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
|
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
|
||||||
|
|
||||||
submodule_to_key: dict[Module, str] = {}
|
submodule_to_key: dict[fl.Module, str] = {}
|
||||||
for name, submodule in dst_model.named_modules():
|
for name, submodule in model.named_modules():
|
||||||
submodule_to_key[submodule] = name
|
submodule_to_key[submodule] = name
|
||||||
|
|
||||||
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
|
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
|
||||||
|
@ -110,14 +92,14 @@ def main() -> None:
|
||||||
#
|
#
|
||||||
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
|
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
|
||||||
|
|
||||||
refiners_to_diffusers = create_mapping(src_model, dst_model) # type: ignore
|
refiners_to_diffusers = create_mapping()
|
||||||
assert refiners_to_diffusers is not None
|
assert refiners_to_diffusers is not None
|
||||||
|
|
||||||
# Compute the corresponding diffusers' keys where LoRA layers must be applied
|
# Compute the corresponding diffusers' keys where LoRA layers must be applied
|
||||||
lora_injection_points: list[str] = [
|
lora_injection_points: list[str] = [
|
||||||
refiners_to_diffusers[submodule_to_key[linear]]
|
refiners_to_diffusers[submodule_to_key[linear]]
|
||||||
for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
|
for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
|
||||||
for layer in dst_model.layers(layer_type=target.get_class())
|
for layer in model.layers(layer_type=target.get_class())
|
||||||
for linear in layer.layers(layer_type=fl.Linear)
|
for linear in layer.layers(layer_type=fl.Linear)
|
||||||
]
|
]
|
||||||
|
|
104
scripts/conversion/convert_transformers_clip_text_model.py
Normal file
104
scripts/conversion/convert_transformers_clip_text_model.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from torch import nn
|
||||||
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
|
from transformers import CLIPTextModelWithProjection # type: ignore
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
|
||||||
|
source_path: str
|
||||||
|
subfolder: str
|
||||||
|
output_path: str | None
|
||||||
|
use_half: bool
|
||||||
|
verbose: bool
|
||||||
|
|
||||||
|
|
||||||
|
def setup_converter(args: Args) -> ModelConverter:
|
||||||
|
source: nn.Module = CLIPTextModelWithProjection.from_pretrained( # type: ignore
|
||||||
|
pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder
|
||||||
|
)
|
||||||
|
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||||
|
architecture: str = source.config.architectures[0] # type: ignore
|
||||||
|
embedding_dim: int = source.config.hidden_size # type: ignore
|
||||||
|
projection_dim: int = source.config.projection_dim # type: ignore
|
||||||
|
num_layers: int = source.config.num_hidden_layers # type: ignore
|
||||||
|
num_attention_heads: int = source.config.num_attention_heads # type: ignore
|
||||||
|
feed_forward_dim: int = source.config.intermediate_size # type: ignore
|
||||||
|
use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore
|
||||||
|
target = CLIPTextEncoder(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_layers=num_layers,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
feedforward_dim=feed_forward_dim,
|
||||||
|
use_quick_gelu=use_quick_gelu,
|
||||||
|
)
|
||||||
|
match architecture:
|
||||||
|
case "CLIPTextModel":
|
||||||
|
source.text_projection = fl.Identity()
|
||||||
|
case "CLIPTextModelWithProjection":
|
||||||
|
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
|
||||||
|
case _:
|
||||||
|
raise RuntimeError(f"Unsupported architecture: {architecture}")
|
||||||
|
text = "What a nice cat you have there!"
|
||||||
|
tokenizer = target.find(layer_type=CLIPTokenizer)
|
||||||
|
assert tokenizer is not None, "Could not find tokenizer"
|
||||||
|
tokens = tokenizer(text)
|
||||||
|
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||||
|
if not converter.run(source_args=(tokens,), target_args=(text,)):
|
||||||
|
raise RuntimeError("Model conversion failed")
|
||||||
|
return converter
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Converts a CLIPTextEncoder from the library transformers from the HuggingFace Hub to refiners."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source_path",
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help=(
|
||||||
|
"Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
|
||||||
|
" runwayml/stable-diffusion-v1-5"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--subfolder",
|
||||||
|
type=str,
|
||||||
|
dest="subfolder",
|
||||||
|
default="text_encoder",
|
||||||
|
help=(
|
||||||
|
"Subfolder in the source path where the model is located inside the Hub. Default: text_encoder (for"
|
||||||
|
" CLIPTextModel)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--to",
|
||||||
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
|
||||||
|
" source path."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True")
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Prints additional information during conversion. Default: False",
|
||||||
|
)
|
||||||
|
args = parser.parse_args(namespace=Args())
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
||||||
|
converter = setup_converter(args=args)
|
||||||
|
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,52 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
|
||||||
|
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
|
||||||
dst_model = CLIPTextEncoderL()
|
|
||||||
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
|
||||||
assert tokenizer is not None, "Could not find tokenizer"
|
|
||||||
tokens = tokenizer("Nice cat")
|
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
|
||||||
assert mapping is not None, "Model conversion failed"
|
|
||||||
state_dict = convert_state_dict(
|
|
||||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
|
||||||
)
|
|
||||||
return {k: v.half() for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source",
|
|
||||||
required=False,
|
|
||||||
default="runwayml/stable-diffusion-v1-5",
|
|
||||||
help="Source model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="CLIPTextEncoderL.safetensors",
|
|
||||||
help="Path for the output file",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder # type: ignore
|
|
||||||
tensors = convert(src_model=src_model) # type: ignore
|
|
||||||
save_to_safetensors(path=args.output_file, tensors=tensors)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,57 +0,0 @@
|
||||||
# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings
|
|
||||||
# Code is at https://github.com/carolineec/informative-drawings
|
|
||||||
# Copy `model.py` in your `PYTHONPATH`. You can edit it to remove un-necessary code
|
|
||||||
# and imports if you want, we only need `Generator`.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
|
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
|
||||||
from model import Generator # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert(checkpoint: str, device: torch.device | str) -> dict[str, torch.Tensor]:
|
|
||||||
src_model = Generator(3, 1, 3) # type: ignore
|
|
||||||
src_model.load_state_dict(torch.load(checkpoint, map_location=device)) # type: ignore
|
|
||||||
src_model.eval() # type: ignore
|
|
||||||
|
|
||||||
dst_model = InformativeDrawings()
|
|
||||||
|
|
||||||
x = torch.randn(1, 3, 512, 512)
|
|
||||||
|
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
|
||||||
assert mapping is not None, "Model conversion failed"
|
|
||||||
state_dict = convert_state_dict(source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping) # type: ignore
|
|
||||||
return {k: v.half() for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source",
|
|
||||||
required=False,
|
|
||||||
default="model2.pth",
|
|
||||||
help="Source model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="informative-drawings.safetensors",
|
|
||||||
help="Path for the output file",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
tensors = convert(checkpoint=args.source, device=device)
|
|
||||||
save_to_safetensors(path=args.output_file, tensors=tensors)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,53 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import (
|
|
||||||
create_state_dict_mapping,
|
|
||||||
convert_state_dict,
|
|
||||||
save_to_safetensors,
|
|
||||||
)
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
from diffusers.models.autoencoder_kl import AutoencoderKL # type: ignore
|
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]:
|
|
||||||
dst_model = LatentDiffusionAutoencoder()
|
|
||||||
x = torch.randn(1, 3, 512, 512)
|
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
|
||||||
assert mapping is not None, "Model conversion failed"
|
|
||||||
state_dict = convert_state_dict(
|
|
||||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
|
||||||
)
|
|
||||||
return {k: v.half() for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source",
|
|
||||||
required=False,
|
|
||||||
default="runwayml/stable-diffusion-v1-5",
|
|
||||||
help="Source model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="lda.safetensors",
|
|
||||||
help="Path for the output file",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).vae # type: ignore
|
|
||||||
tensors = convert(src_model=src_model) # type: ignore
|
|
||||||
save_to_safetensors(path=args.output_file, tensors=tensors)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,58 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
|
|
||||||
|
|
||||||
from diffusers import StableDiffusionInpaintPipeline # type: ignore
|
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
|
||||||
dst_model = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
|
||||||
|
|
||||||
x = torch.randn(1, 9, 32, 32)
|
|
||||||
timestep = torch.tensor(data=[0])
|
|
||||||
clip_text_embeddings = torch.randn(1, 77, 768)
|
|
||||||
|
|
||||||
src_args = (x, timestep, clip_text_embeddings)
|
|
||||||
dst_model.set_timestep(timestep=timestep)
|
|
||||||
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
|
||||||
dst_args = (x,)
|
|
||||||
|
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore
|
|
||||||
assert mapping is not None, "Model conversion failed"
|
|
||||||
state_dict = convert_state_dict(
|
|
||||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
|
||||||
)
|
|
||||||
return {k: v.half() for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source",
|
|
||||||
required=False,
|
|
||||||
default="runwayml/stable-diffusion-inpainting",
|
|
||||||
help="Source model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="stable_diffusion_1_5_inpainting_unet.safetensors",
|
|
||||||
help="Path for the output file",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
src_model = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
|
|
||||||
tensors = convert(src_model=src_model) # type: ignore
|
|
||||||
save_to_safetensors(path=args.output_file, tensors=tensors)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,58 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
|
||||||
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
|
||||||
|
|
||||||
x = torch.randn(1, 4, 32, 32)
|
|
||||||
timestep = torch.tensor(data=[0])
|
|
||||||
clip_text_embeddings = torch.randn(1, 77, 768)
|
|
||||||
|
|
||||||
src_args = (x, timestep, clip_text_embeddings)
|
|
||||||
dst_model.set_timestep(timestep=timestep)
|
|
||||||
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
|
||||||
dst_args = (x,)
|
|
||||||
|
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args) # type: ignore
|
|
||||||
assert mapping is not None, "Model conversion failed"
|
|
||||||
state_dict = convert_state_dict(
|
|
||||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
|
||||||
)
|
|
||||||
return {k: v.half() for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source",
|
|
||||||
required=False,
|
|
||||||
default="runwayml/stable-diffusion-v1-5",
|
|
||||||
help="Source model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="stable_diffusion_1_5_unet.safetensors",
|
|
||||||
help="Path for the output file",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
|
|
||||||
tensors = convert(src_model=src_model) # type: ignore
|
|
||||||
save_to_safetensors(path=args.output_file, tensors=tensors)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,57 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
from safetensors.torch import save_file # type: ignore
|
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
|
||||||
|
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
|
|
||||||
import refiners.fluxion.layers as fl
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
|
||||||
dst_model = CLIPTextEncoderG()
|
|
||||||
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
|
|
||||||
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
|
||||||
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
|
||||||
assert tokenizer is not None, "Could not find tokenizer"
|
|
||||||
tokens = tokenizer("Nice cat")
|
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
|
||||||
if mapping is None:
|
|
||||||
raise RuntimeError("Could not create state dict mapping")
|
|
||||||
state_dict = convert_state_dict(
|
|
||||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
|
||||||
)
|
|
||||||
return {k: v.half() for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source",
|
|
||||||
required=False,
|
|
||||||
default="stabilityai/stable-diffusion-xl-base-0.9",
|
|
||||||
help="Source model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="CLIPTextEncoderG.safetensors",
|
|
||||||
help="Path for the output file",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2 # type: ignore
|
|
||||||
tensors = convert(src_model=src_model) # type: ignore
|
|
||||||
save_file(tensors=tensors, filename=args.output_file)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,65 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
from safetensors.torch import save_file # type: ignore
|
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
|
||||||
dst_model = SDXLUNet(in_channels=4)
|
|
||||||
|
|
||||||
x = torch.randn(1, 4, 32, 32)
|
|
||||||
timestep = torch.tensor(data=[0])
|
|
||||||
clip_text_embeddings = torch.randn(1, 77, 2048)
|
|
||||||
|
|
||||||
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
|
||||||
src_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs)
|
|
||||||
dst_model.set_timestep(timestep=timestep)
|
|
||||||
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
|
||||||
dst_model.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
|
||||||
dst_model.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
|
||||||
dst_args = (x,)
|
|
||||||
|
|
||||||
mapping = create_state_dict_mapping(
|
|
||||||
source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args # type: ignore
|
|
||||||
)
|
|
||||||
if mapping is None:
|
|
||||||
raise RuntimeError("Could not create state dict mapping")
|
|
||||||
state_dict = convert_state_dict(
|
|
||||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
|
||||||
)
|
|
||||||
return {k: v for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source",
|
|
||||||
required=False,
|
|
||||||
default="stabilityai/stable-diffusion-xl-base-0.9",
|
|
||||||
help="Source model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="stable_diffusion_xl_unet.safetensors",
|
|
||||||
help="Path for the output file",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
|
|
||||||
tensors = convert(src_model=src_model) # type: ignore
|
|
||||||
save_file(tensors=tensors, filename=args.output_file)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
545
src/refiners/fluxion/model_converter.py
Normal file
545
src/refiners/fluxion/model_converter.py
Normal file
|
@ -0,0 +1,545 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
from enum import Enum, auto
|
||||||
|
from pathlib import Path
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
import torch
|
||||||
|
from typing import Any, DefaultDict, TypedDict
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import norm, save_to_safetensors
|
||||||
|
|
||||||
|
TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
|
||||||
|
nn.Conv1d,
|
||||||
|
nn.Conv2d,
|
||||||
|
nn.Conv3d,
|
||||||
|
nn.ConvTranspose1d,
|
||||||
|
nn.ConvTranspose2d,
|
||||||
|
nn.ConvTranspose3d,
|
||||||
|
nn.Linear,
|
||||||
|
nn.BatchNorm1d,
|
||||||
|
nn.BatchNorm2d,
|
||||||
|
nn.BatchNorm3d,
|
||||||
|
nn.LayerNorm,
|
||||||
|
nn.GroupNorm,
|
||||||
|
nn.Embedding,
|
||||||
|
nn.MaxPool2d,
|
||||||
|
nn.AvgPool2d,
|
||||||
|
nn.AdaptiveAvgPool2d,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
ModelTypeShape = tuple[str, tuple[torch.Size, ...]]
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleArgsDict(TypedDict):
|
||||||
|
"""Represents positional and keyword arguments passed to a module.
|
||||||
|
|
||||||
|
- `positional`: A tuple of positional arguments.
|
||||||
|
- `keyword`: A dictionary of keyword arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
positional: tuple[Any, ...]
|
||||||
|
keyword: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class ConversionStage(Enum):
|
||||||
|
"""Represents the current stage of the conversion process.
|
||||||
|
|
||||||
|
- `INIT`: The conversion process has not started.
|
||||||
|
- `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
INIT = auto()
|
||||||
|
BASIC_LAYERS_MATCH = auto()
|
||||||
|
SHAPE_AND_LAYERS_MATCH = auto()
|
||||||
|
MODELS_OUTPUT_AGREE = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConverter:
|
||||||
|
ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
|
||||||
|
stage: ConversionStage = ConversionStage.INIT
|
||||||
|
_stored_mapping: dict[str, str] | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
source_model: nn.Module,
|
||||||
|
target_model: nn.Module,
|
||||||
|
source_keys_to_skip: list[str] | None = None,
|
||||||
|
target_keys_to_skip: list[str] | None = None,
|
||||||
|
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
|
||||||
|
threshold: float = 1e-5,
|
||||||
|
skip_output_check: bool = False,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create a ModelConverter.
|
||||||
|
|
||||||
|
- `source_model`: The model to convert from.
|
||||||
|
- `target_model`: The model to convert to.
|
||||||
|
- `source_keys_to_skip`: A list of keys to skip when tracing the source model.
|
||||||
|
- `target_keys_to_skip`: A list of keys to skip when tracing the target model.
|
||||||
|
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
|
||||||
|
- `threshold`: The threshold for comparing outputs between the source and target models.
|
||||||
|
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
|
||||||
|
- `verbose`: Whether to print messages during the conversion process.
|
||||||
|
|
||||||
|
The conversion process consists of three stages:
|
||||||
|
|
||||||
|
1. Verify that the source and target models have the same number of basic layers.
|
||||||
|
2. Find matching shapes and layers between the source and target models.
|
||||||
|
3. Convert the source model's state_dict to match the target model's state_dict.
|
||||||
|
4. Compare the outputs of the source and target models.
|
||||||
|
|
||||||
|
The conversion process can be run multiple times, and will resume from the last stage.
|
||||||
|
|
||||||
|
### Example:
|
||||||
|
```
|
||||||
|
converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
|
||||||
|
is_converted = converter(args)
|
||||||
|
if is_converted:
|
||||||
|
converter.save_to_safetensors(path="test.pt")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
self.source_model = source_model
|
||||||
|
self.target_model = target_model
|
||||||
|
self.source_keys_to_skip = source_keys_to_skip or []
|
||||||
|
self.target_keys_to_skip = target_keys_to_skip or []
|
||||||
|
self.custom_layer_mapping = custom_layer_mapping or {}
|
||||||
|
self.threshold = threshold
|
||||||
|
self.skip_output_check = skip_output_check
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
|
||||||
|
"""
|
||||||
|
Run the conversion process.
|
||||||
|
|
||||||
|
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||||
|
is not provided, these arguments will also be passed to the target model.
|
||||||
|
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
|
||||||
|
### Returns:
|
||||||
|
|
||||||
|
- `True` if the conversion process is done and the models agree.
|
||||||
|
|
||||||
|
The conversion process consists of three stages:
|
||||||
|
|
||||||
|
1. Verify that the source and target models have the same number of basic layers.
|
||||||
|
2. Find matching shapes and layers between the source and target models.
|
||||||
|
3. Convert the source model's state_dict to match the target model's state_dict.
|
||||||
|
4. Compare the outputs of the source and target models.
|
||||||
|
|
||||||
|
The conversion process can be run multiple times, and will resume from the last stage.
|
||||||
|
"""
|
||||||
|
if target_args is None:
|
||||||
|
target_args = source_args
|
||||||
|
|
||||||
|
match self.stage:
|
||||||
|
case ConversionStage.MODELS_OUTPUT_AGREE:
|
||||||
|
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
|
||||||
|
return True
|
||||||
|
|
||||||
|
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_models_output_agree_stage():
|
||||||
|
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
|
||||||
|
return True
|
||||||
|
|
||||||
|
case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
|
||||||
|
source_args=source_args, target_args=target_args
|
||||||
|
):
|
||||||
|
self.stage = (
|
||||||
|
ConversionStage.SHAPE_AND_LAYERS_MATCH
|
||||||
|
if not self.skip_output_check
|
||||||
|
else ConversionStage.MODELS_OUTPUT_AGREE
|
||||||
|
)
|
||||||
|
return self.run(source_args=source_args, target_args=target_args)
|
||||||
|
|
||||||
|
case ConversionStage.INIT if self._run_init_stage():
|
||||||
|
self.stage = ConversionStage.BASIC_LAYERS_MATCH
|
||||||
|
return self.run(source_args=source_args, target_args=target_args)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
|
||||||
|
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return self.stage == ConversionStage.MODELS_OUTPUT_AGREE
|
||||||
|
|
||||||
|
def get_state_dict(self) -> dict[str, Tensor]:
|
||||||
|
"""Get the converted state_dict."""
|
||||||
|
if not self:
|
||||||
|
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
|
||||||
|
return self.target_model.state_dict()
|
||||||
|
|
||||||
|
def get_mapping(self) -> dict[str, str]:
|
||||||
|
"""Get the mapping between the source and target models' state_dicts."""
|
||||||
|
if not self:
|
||||||
|
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
|
||||||
|
assert self._stored_mapping is not None, "Mapping is not stored"
|
||||||
|
return self._stored_mapping
|
||||||
|
|
||||||
|
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
|
||||||
|
"""Save the converted model to a SafeTensors file.
|
||||||
|
|
||||||
|
This method can only be called after the conversion process is done.
|
||||||
|
|
||||||
|
- `path`: The path to save the converted model to.
|
||||||
|
- `metadata`: Metadata to save with the converted model.
|
||||||
|
- `half`: Whether to save the converted model as half precision.
|
||||||
|
|
||||||
|
### Raises:
|
||||||
|
- `ValueError` if the conversion process is not done yet. Run `converter(args)` first.
|
||||||
|
"""
|
||||||
|
if not self:
|
||||||
|
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
|
||||||
|
state_dict = self.get_state_dict()
|
||||||
|
if half:
|
||||||
|
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||||
|
save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)
|
||||||
|
|
||||||
|
def map_state_dicts(
|
||||||
|
self,
|
||||||
|
source_args: ModuleArgs,
|
||||||
|
target_args: ModuleArgs | None = None,
|
||||||
|
) -> dict[str, str] | None:
|
||||||
|
"""
|
||||||
|
Find a mapping between the source and target models' state_dicts.
|
||||||
|
|
||||||
|
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||||
|
is not provided, these arguments will also be passed to the target model.
|
||||||
|
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
|
||||||
|
### Returns:
|
||||||
|
- A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
|
||||||
|
"""
|
||||||
|
if target_args is None:
|
||||||
|
target_args = source_args
|
||||||
|
|
||||||
|
source_order = self._trace_module_execution_order(
|
||||||
|
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
|
||||||
|
)
|
||||||
|
target_order = self._trace_module_execution_order(
|
||||||
|
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
|
||||||
|
return None
|
||||||
|
|
||||||
|
mapping: dict[str, str] = {}
|
||||||
|
for model_type_shape in source_order:
|
||||||
|
source_keys = source_order[model_type_shape]
|
||||||
|
target_keys = target_order[model_type_shape]
|
||||||
|
mapping.update(zip(target_keys, source_keys))
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
def compare_models(
|
||||||
|
self,
|
||||||
|
source_args: ModuleArgs,
|
||||||
|
target_args: ModuleArgs | None = None,
|
||||||
|
threshold: float = 1e-5,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Compare the outputs of the source and target models.
|
||||||
|
|
||||||
|
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||||
|
is not provided, these arguments will also be passed to the target model.
|
||||||
|
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
- `threshold`: The threshold for comparing outputs between the source and target models.
|
||||||
|
"""
|
||||||
|
if target_args is None:
|
||||||
|
target_args = source_args
|
||||||
|
|
||||||
|
source_outputs = self._collect_layers_outputs(
|
||||||
|
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
|
||||||
|
)
|
||||||
|
target_outputs = self._collect_layers_outputs(
|
||||||
|
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
|
||||||
|
)
|
||||||
|
|
||||||
|
prev_source_key, prev_target_key = None, None
|
||||||
|
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
|
||||||
|
diff = norm(source_output - target_output).item()
|
||||||
|
if diff > threshold:
|
||||||
|
self._log(
|
||||||
|
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
|
||||||
|
f" {target_key}, difference in norm: {diff}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
prev_source_key, prev_target_key = source_key, target_key
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _run_init_stage(self) -> bool:
|
||||||
|
"""Run the init stage of the conversion process."""
|
||||||
|
is_count_correct = self._verify_basic_layers_count()
|
||||||
|
is_not_missing_layers = self._verify_missing_basic_layers()
|
||||||
|
|
||||||
|
return is_count_correct and is_not_missing_layers
|
||||||
|
|
||||||
|
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
|
||||||
|
"""Run the basic layers match stage of the conversion process."""
|
||||||
|
self._log(message="Finding matching shapes and layers...")
|
||||||
|
|
||||||
|
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
|
||||||
|
self._stored_mapping = mapping
|
||||||
|
if mapping is None:
|
||||||
|
self._log(message="Models do not have matching shapes.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._log(message="Found matching shapes and layers. Converting state_dict...")
|
||||||
|
|
||||||
|
source_state_dict = self.source_model.state_dict()
|
||||||
|
target_state_dict = self.target_model.state_dict()
|
||||||
|
converted_state_dict = self._convert_state_dict(
|
||||||
|
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||||
|
)
|
||||||
|
self.target_model.load_state_dict(state_dict=converted_state_dict)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
|
||||||
|
"""Run the shape and layers match stage of the conversion process."""
|
||||||
|
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
|
||||||
|
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _run_models_output_agree_stage(self) -> bool:
|
||||||
|
"""Run the models output agree stage of the conversion process."""
|
||||||
|
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _log(self, message: str) -> None:
|
||||||
|
"""Print a message if `verbose` is `True`."""
|
||||||
|
if self.verbose:
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
def _debug_print_shapes(
|
||||||
|
self,
|
||||||
|
shape: ModelTypeShape,
|
||||||
|
source_keys: list[str],
|
||||||
|
target_keys: list[str],
|
||||||
|
) -> None:
|
||||||
|
"""Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
|
||||||
|
self._log(message=f"{shape}")
|
||||||
|
max_len = max(len(source_keys), len(target_keys))
|
||||||
|
for i in range(max_len):
|
||||||
|
source_key = source_keys[i] if i < len(source_keys) else "---"
|
||||||
|
target_key = target_keys[i] if i < len(target_keys) else "---"
|
||||||
|
self._log(f"\t{source_key}\t{target_key}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||||
|
"""Unpack the positional and keyword arguments passed to a module."""
|
||||||
|
match module_args:
|
||||||
|
case tuple(positional_args):
|
||||||
|
keyword_args: dict[str, Any] = {}
|
||||||
|
case {"positional": positional_args, "keyword": keyword_args}:
|
||||||
|
pass
|
||||||
|
case _:
|
||||||
|
positional_args = ()
|
||||||
|
keyword_args = dict(**module_args)
|
||||||
|
|
||||||
|
return positional_args, keyword_args
|
||||||
|
|
||||||
|
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
|
||||||
|
"""Infer the type of a basic layer."""
|
||||||
|
for layer_type in TORCH_BASIC_LAYERS:
|
||||||
|
if isinstance(module, layer_type):
|
||||||
|
return layer_type
|
||||||
|
|
||||||
|
for source_type in self.custom_layer_mapping.keys():
|
||||||
|
if isinstance(module, source_type):
|
||||||
|
return source_type
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
|
||||||
|
"""Get the signature of a module."""
|
||||||
|
layer_type = self._infer_basic_layer_type(module=module)
|
||||||
|
assert layer_type is not None, f"Module {module} is not a basic layer"
|
||||||
|
param_shapes = [p.shape for p in module.parameters()]
|
||||||
|
return (str(object=layer_type), tuple(param_shapes))
|
||||||
|
|
||||||
|
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
|
||||||
|
"""Count the number of basic layers in a module."""
|
||||||
|
count: DefaultDict[type[nn.Module], int] = defaultdict(int)
|
||||||
|
for submodule in module.modules():
|
||||||
|
layer_type = self._infer_basic_layer_type(module=submodule)
|
||||||
|
if layer_type is not None:
|
||||||
|
count[layer_type] += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _verify_basic_layers_count(self) -> bool:
|
||||||
|
"""Verify that the source and target models have the same number of basic layers."""
|
||||||
|
source_layers = self._count_basic_layers(module=self.source_model)
|
||||||
|
target_layers = self._count_basic_layers(module=self.target_model)
|
||||||
|
|
||||||
|
diff: dict[type[nn.Module], tuple[int, int]] = {}
|
||||||
|
for layer_type in set(source_layers.keys()) | set(target_layers.keys()):
|
||||||
|
source_count = source_layers.get(layer_type, 0)
|
||||||
|
target_count = target_layers.get(layer_type, 0)
|
||||||
|
if source_count != target_count:
|
||||||
|
diff[layer_type] = (source_count, target_count)
|
||||||
|
|
||||||
|
if diff:
|
||||||
|
message = "Models do not have the same number of basic layers:\n"
|
||||||
|
for layer_type, counts in diff.items():
|
||||||
|
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
|
||||||
|
self._log(message=message.rstrip())
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
|
||||||
|
"""Check if a module is a leaf module with weights."""
|
||||||
|
return next(module.parameters(), None) is not None and next(module.children(), None) is None
|
||||||
|
|
||||||
|
def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
|
||||||
|
"""Check if a module has weighted leaf modules that are not basic layers."""
|
||||||
|
return [
|
||||||
|
type(submodule)
|
||||||
|
for submodule in module.modules()
|
||||||
|
if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _verify_missing_basic_layers(self) -> bool:
|
||||||
|
"""Verify that the source and target models do not have missing basic layers."""
|
||||||
|
missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
|
||||||
|
missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)
|
||||||
|
|
||||||
|
if missing_source_layers or missing_target_layers:
|
||||||
|
self._log(
|
||||||
|
message=(
|
||||||
|
"Models might have missing basic layers. You can either pass them into keys to skip or set"
|
||||||
|
f" `check_missing_basic_layer` to `False`: {missing_source_layers}, {missing_target_layers}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _trace_module_execution_order(
|
||||||
|
self,
|
||||||
|
module: nn.Module,
|
||||||
|
args: ModuleArgs,
|
||||||
|
keys_to_skip: list[str],
|
||||||
|
) -> dict[ModelTypeShape, list[str]]:
|
||||||
|
"""
|
||||||
|
Execute a forward pass and store the order of execution of specific sub-modules.
|
||||||
|
|
||||||
|
- `module`: The module to trace.
|
||||||
|
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
- `keys_to_skip`: A list of keys to skip when tracing the module.
|
||||||
|
|
||||||
|
### Returns:
|
||||||
|
- A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
|
||||||
|
"""
|
||||||
|
submodule_to_key: dict[nn.Module, str] = {}
|
||||||
|
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
||||||
|
|
||||||
|
def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
|
||||||
|
layer_signature = self.get_module_signature(module=layer)
|
||||||
|
execution_order[layer_signature].append(submodule_to_key[layer])
|
||||||
|
|
||||||
|
hooks: list[RemovableHandle] = []
|
||||||
|
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
|
||||||
|
for name, submodule in named_modules:
|
||||||
|
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
|
||||||
|
submodule_to_key[submodule] = name # type: ignore
|
||||||
|
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||||
|
hooks.append(hook)
|
||||||
|
|
||||||
|
positional_args, keyword_args = self._unpack_module_args(module_args=args)
|
||||||
|
module(*positional_args, **keyword_args)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
return dict(execution_order)
|
||||||
|
|
||||||
|
def _assert_shapes_aligned(
|
||||||
|
self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
|
||||||
|
) -> bool:
|
||||||
|
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
|
||||||
|
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
||||||
|
shape_missmatched = False
|
||||||
|
|
||||||
|
for model_type_shape in model_type_shapes:
|
||||||
|
source_keys = source_order.get(model_type_shape, [])
|
||||||
|
target_keys = target_order.get(model_type_shape, [])
|
||||||
|
|
||||||
|
if len(source_keys) != len(target_keys):
|
||||||
|
shape_missmatched = True
|
||||||
|
self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
|
||||||
|
|
||||||
|
return not shape_missmatched
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_state_dict(
|
||||||
|
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
|
"""Convert the source model's state_dict to match the target model's state_dict."""
|
||||||
|
converted_state_dict: dict[str, Tensor] = {}
|
||||||
|
for target_key in target_state_dict:
|
||||||
|
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
|
||||||
|
source_prefix = state_dict_mapping[target_prefix]
|
||||||
|
source_key = ".".join([source_prefix, suffix])
|
||||||
|
converted_state_dict[target_key] = source_state_dict[source_key]
|
||||||
|
|
||||||
|
return converted_state_dict
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _collect_layers_outputs(
|
||||||
|
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
|
||||||
|
) -> list[tuple[str, Tensor]]:
|
||||||
|
"""
|
||||||
|
Execute a forward pass and store the output of specific sub-modules.
|
||||||
|
|
||||||
|
- `module`: The module to trace.
|
||||||
|
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
- `keys_to_skip`: A list of keys to skip when tracing the module.
|
||||||
|
|
||||||
|
### Returns:
|
||||||
|
- A list of tuples containing the key of each sub-module and its output.
|
||||||
|
|
||||||
|
### Note:
|
||||||
|
- The output of each sub-module is cloned to avoid memory leaks.
|
||||||
|
"""
|
||||||
|
submodule_to_key: dict[nn.Module, str] = {}
|
||||||
|
execution_order: list[tuple[str, Tensor]] = []
|
||||||
|
|
||||||
|
def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None:
|
||||||
|
execution_order.append((submodule_to_key[layer], output.clone()))
|
||||||
|
|
||||||
|
hooks: list[RemovableHandle] = []
|
||||||
|
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
|
||||||
|
for name, submodule in named_modules:
|
||||||
|
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
|
||||||
|
submodule_to_key[submodule] = name # type: ignore
|
||||||
|
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||||
|
hooks.append(hook)
|
||||||
|
|
||||||
|
positional_args, keyword_args = self._unpack_module_args(module_args=args)
|
||||||
|
module(*positional_args, **keyword_args)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
return execution_order
|
|
@ -1,18 +1,13 @@
|
||||||
from collections import defaultdict
|
from typing import Dict, Iterable, Literal, TypeVar
|
||||||
from enum import Enum
|
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from numpy import array, float32
|
from numpy import array, float32
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from safetensors import safe_open as _safe_open # type: ignore
|
from safetensors import safe_open as _safe_open # type: ignore
|
||||||
from safetensors.torch import save_file as _save_file # type: ignore
|
from safetensors.torch import save_file as _save_file # type: ignore
|
||||||
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
||||||
|
import torch
|
||||||
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
||||||
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.utils.hooks import RemovableHandle
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from refiners.fluxion.layers.module import Module
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
@ -31,7 +26,7 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0) -> Tensor:
|
||||||
return _pad(input=x, pad=pad, value=value) # type: ignore
|
return _pad(input=x, pad=pad, value=value) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def interpolate(x: Tensor, factor: float | Size, mode: str = "nearest") -> Tensor:
|
def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor:
|
||||||
return (
|
return (
|
||||||
_interpolate(x, scale_factor=factor, mode=mode)
|
_interpolate(x, scale_factor=factor, mode=mode)
|
||||||
if isinstance(factor, float | int)
|
if isinstance(factor, float | int)
|
||||||
|
@ -44,7 +39,9 @@ def bidirectional_mapping(mapping: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
|
||||||
|
|
||||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||||
return tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(0)
|
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
||||||
|
0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_image(tensor: Tensor) -> Image.Image:
|
def tensor_to_image(tensor: Tensor) -> Image.Image:
|
||||||
|
@ -77,220 +74,3 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
|
||||||
|
|
||||||
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
||||||
_save_file(tensors, path, metadata) # type: ignore
|
_save_file(tensors, path, metadata) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class BasicLayers(Enum):
|
|
||||||
Conv1d = nn.Conv1d
|
|
||||||
Conv2d = nn.Conv2d
|
|
||||||
Conv3d = nn.Conv3d
|
|
||||||
ConvTranspose1d = nn.ConvTranspose1d
|
|
||||||
ConvTranspose2d = nn.ConvTranspose2d
|
|
||||||
ConvTranspose3d = nn.ConvTranspose3d
|
|
||||||
Linear = nn.Linear
|
|
||||||
BatchNorm1d = nn.BatchNorm1d
|
|
||||||
BatchNorm2d = nn.BatchNorm2d
|
|
||||||
BatchNorm3d = nn.BatchNorm3d
|
|
||||||
LayerNorm = nn.LayerNorm
|
|
||||||
GroupNorm = nn.GroupNorm
|
|
||||||
Embedding = nn.Embedding
|
|
||||||
MaxPool2d = nn.MaxPool2d
|
|
||||||
AvgPool2d = nn.AvgPool2d
|
|
||||||
AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d
|
|
||||||
|
|
||||||
|
|
||||||
ModelTypeShape = tuple[str, tuple[Size, ...]]
|
|
||||||
|
|
||||||
|
|
||||||
def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
|
|
||||||
"""Identify if the provided module matches any in the BasicLayers enum."""
|
|
||||||
for layer_type in BasicLayers:
|
|
||||||
if isinstance(module, layer_type.value):
|
|
||||||
return layer_type
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_module_signature(module: "Module") -> ModelTypeShape:
|
|
||||||
"""Return a tuple representing the module's type and parameter shapes."""
|
|
||||||
layer_type = infer_basic_layer_type(module=module)
|
|
||||||
assert layer_type is not None, f"Module {module} is not a basic layer"
|
|
||||||
param_shapes = [p.shape for p in module.parameters()]
|
|
||||||
return (str(object=layer_type), tuple(param_shapes))
|
|
||||||
|
|
||||||
|
|
||||||
def forward_order_of_execution(
|
|
||||||
module: "Module",
|
|
||||||
example_args: tuple[Any, ...],
|
|
||||||
key_skipper: Callable[[str], bool] | None = None,
|
|
||||||
) -> dict[ModelTypeShape, list[str]]:
|
|
||||||
"""
|
|
||||||
Determine the execution order of sub-modules during a forward pass.
|
|
||||||
|
|
||||||
Optionally skips specific modules using `key_skipper`.
|
|
||||||
"""
|
|
||||||
key_skipper = key_skipper or (lambda _: False)
|
|
||||||
|
|
||||||
submodule_to_key: dict["Module", str] = {}
|
|
||||||
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
|
||||||
|
|
||||||
def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
|
|
||||||
layer_signature = get_module_signature(module=layer)
|
|
||||||
execution_order[layer_signature].append(submodule_to_key[layer])
|
|
||||||
|
|
||||||
hooks: list[RemovableHandle] = []
|
|
||||||
for name, submodule in module.named_modules():
|
|
||||||
if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
|
|
||||||
submodule_to_key[submodule] = name
|
|
||||||
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
|
||||||
hooks.append(hook)
|
|
||||||
|
|
||||||
with no_grad():
|
|
||||||
module(*example_args)
|
|
||||||
|
|
||||||
for hook in hooks:
|
|
||||||
hook.remove()
|
|
||||||
|
|
||||||
return dict(execution_order)
|
|
||||||
|
|
||||||
|
|
||||||
def print_side_by_side(
|
|
||||||
shape: ModelTypeShape,
|
|
||||||
source_keys: list[str],
|
|
||||||
target_keys: list[str],
|
|
||||||
) -> None:
|
|
||||||
"""Print module keys side by side, useful for debugging shape mismatches."""
|
|
||||||
print(f"{shape}")
|
|
||||||
max_len = max(len(source_keys), len(target_keys))
|
|
||||||
for i in range(max_len):
|
|
||||||
source_key = source_keys[i] if i < len(source_keys) else "---"
|
|
||||||
target_key = target_keys[i] if i < len(target_keys) else "---"
|
|
||||||
print(f"\t{source_key}\t{target_key}")
|
|
||||||
|
|
||||||
|
|
||||||
def verify_shape_match(
|
|
||||||
source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the sub-modules in source and target have matching shapes."""
|
|
||||||
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
|
||||||
shape_missmatched = False
|
|
||||||
|
|
||||||
for model_type_shape in model_type_shapes:
|
|
||||||
source_keys = source_order.get(model_type_shape, [])
|
|
||||||
target_keys = target_order.get(model_type_shape, [])
|
|
||||||
|
|
||||||
if len(source_keys) != len(target_keys):
|
|
||||||
shape_missmatched = True
|
|
||||||
print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
|
|
||||||
|
|
||||||
return not shape_missmatched
|
|
||||||
|
|
||||||
|
|
||||||
def create_state_dict_mapping(
|
|
||||||
source_model: "Module",
|
|
||||||
target_model: "Module",
|
|
||||||
source_args: tuple[Any, ...],
|
|
||||||
target_args: tuple[Any, ...] | None = None,
|
|
||||||
source_key_skipper: Callable[[str], bool] | None = None,
|
|
||||||
target_key_skipper: Callable[[str], bool] | None = None,
|
|
||||||
) -> dict[str, str] | None:
|
|
||||||
"""
|
|
||||||
Create a mapping between state_dict keys of the source and target models.
|
|
||||||
|
|
||||||
This facilitates the transfer of weights when architectures have slight differences.
|
|
||||||
"""
|
|
||||||
if target_args is None:
|
|
||||||
target_args = source_args
|
|
||||||
|
|
||||||
source_order = forward_order_of_execution(
|
|
||||||
module=source_model, example_args=source_args, key_skipper=source_key_skipper
|
|
||||||
)
|
|
||||||
target_order = forward_order_of_execution(
|
|
||||||
module=target_model, example_args=target_args, key_skipper=target_key_skipper
|
|
||||||
)
|
|
||||||
|
|
||||||
if not verify_shape_match(source_order=source_order, target_order=target_order):
|
|
||||||
return None
|
|
||||||
|
|
||||||
mapping: dict[str, str] = {}
|
|
||||||
for model_type_shape in source_order:
|
|
||||||
source_keys = source_order[model_type_shape]
|
|
||||||
target_keys = target_order[model_type_shape]
|
|
||||||
mapping.update(zip(target_keys, source_keys))
|
|
||||||
|
|
||||||
return mapping
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict(
|
|
||||||
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
"""Convert source state_dict based on the provided mapping to match target state_dict structure."""
|
|
||||||
converted_state_dict: dict[str, Tensor] = {}
|
|
||||||
for target_key in target_state_dict:
|
|
||||||
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
|
|
||||||
source_prefix = state_dict_mapping[target_prefix]
|
|
||||||
source_key = ".".join([source_prefix, suffix])
|
|
||||||
converted_state_dict[target_key] = source_state_dict[source_key]
|
|
||||||
|
|
||||||
return converted_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def forward_store_outputs(
|
|
||||||
module: "Module",
|
|
||||||
example_args: tuple[Any, ...],
|
|
||||||
key_skipper: Callable[[str], bool] | None = None,
|
|
||||||
) -> list[tuple[str, Tensor]]:
|
|
||||||
"""Execute a forward pass and store outputs of specific sub-modules."""
|
|
||||||
key_skipper = key_skipper or (lambda _: False)
|
|
||||||
submodule_to_key: dict["Module", str] = {}
|
|
||||||
execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
|
|
||||||
|
|
||||||
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None:
|
|
||||||
execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
|
|
||||||
|
|
||||||
hooks: list[RemovableHandle] = []
|
|
||||||
for name, submodule in module.named_modules():
|
|
||||||
if (infer_basic_layer_type(module=module) is not None) and not key_skipper(name):
|
|
||||||
submodule_to_key[submodule] = name
|
|
||||||
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
|
||||||
hooks.append(hook)
|
|
||||||
|
|
||||||
with no_grad():
|
|
||||||
module(*example_args)
|
|
||||||
|
|
||||||
for hook in hooks:
|
|
||||||
hook.remove()
|
|
||||||
|
|
||||||
return execution_order
|
|
||||||
|
|
||||||
|
|
||||||
def compare_models(
|
|
||||||
source_model: "Module",
|
|
||||||
target_model: "Module",
|
|
||||||
source_args: tuple[Any, ...],
|
|
||||||
target_args: tuple[Any, ...] | None = None,
|
|
||||||
source_key_skipper: Callable[[str], bool] | None = None,
|
|
||||||
target_key_skipper: Callable[[str], bool] | None = None,
|
|
||||||
threshold: float = 1e-5,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Compare the outputs of two models given the same inputs.
|
|
||||||
|
|
||||||
Flag if any difference exceeds the given threshold.
|
|
||||||
"""
|
|
||||||
if target_args is None:
|
|
||||||
target_args = source_args
|
|
||||||
|
|
||||||
source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper)
|
|
||||||
target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper)
|
|
||||||
|
|
||||||
prev_source_key, prev_target_key = None, None
|
|
||||||
for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
|
|
||||||
diff = norm(source_output - target_output).item()
|
|
||||||
if diff > threshold:
|
|
||||||
print(
|
|
||||||
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
|
|
||||||
f" {target_key}, difference in norm: {diff}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
prev_source_key, prev_target_key = source_key, target_key
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
|
@ -3,9 +3,10 @@ from pathlib import Path
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from refiners.fluxion.utils import manual_seed
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||||
from refiners.fluxion.utils import compare_models
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -47,23 +48,28 @@ def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet:
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
|
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
|
||||||
torch.manual_seed(seed=0) # type: ignore
|
source = diffusers_sdxl_unet
|
||||||
|
target = refiners_sdxl_unet
|
||||||
|
|
||||||
|
manual_seed(seed=0)
|
||||||
x = torch.randn(1, 4, 32, 32)
|
x = torch.randn(1, 4, 32, 32)
|
||||||
timestep = torch.tensor(data=[0])
|
timestep = torch.tensor(data=[0])
|
||||||
clip_text_embeddings = torch.randn(1, 77, 2048)
|
clip_text_embeddings = torch.randn(1, 77, 2048)
|
||||||
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
||||||
source_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs)
|
|
||||||
|
|
||||||
refiners_sdxl_unet.set_timestep(timestep=timestep)
|
target.set_timestep(timestep=timestep)
|
||||||
refiners_sdxl_unet.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
||||||
refiners_sdxl_unet.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
||||||
refiners_sdxl_unet.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
||||||
target_args = (x,)
|
target_args = (x,)
|
||||||
|
source_args = {
|
||||||
|
"positional": (x, timestep, clip_text_embeddings),
|
||||||
|
"keyword": {"added_cond_kwargs": added_cond_kwargs},
|
||||||
|
}
|
||||||
|
|
||||||
assert compare_models(
|
converter = ModelConverter(source_model=source, target_model=target, verbose=False, threshold=1e-2)
|
||||||
source_model=diffusers_sdxl_unet,
|
|
||||||
target_model=refiners_sdxl_unet,
|
assert converter.run(
|
||||||
source_args=source_args,
|
source_args=source_args,
|
||||||
target_args=target_args,
|
target_args=target_args,
|
||||||
threshold=1e-2,
|
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue