refiners/scripts/conversion/convert_transformers_clip_image_model.py

150 lines
5.5 KiB
Python
Raw Permalink Normal View History

import argparse
from pathlib import Path
from typing import NamedTuple, cast
import torch
from torch import nn
from transformers import CLIPVisionModelWithProjection # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
class Args(argparse.Namespace):
source_path: str
subfolder: str
output_path: str | None
half: bool
verbose: bool
threshold: float
class CLIPImageEncoderConfig(NamedTuple):
    """Typed view of the configuration attributes read from the source model.

    `setup_converter` only uses this class as a `cast` target for `source.config`, so that attribute
    access on the config is type-checked; no instance is built from it there.
    """

    # Checked against "CLIPVisionModelWithProjection" (first entry only).
    architectures: list[str]
    # Must be 3 for the conversion to proceed.
    num_channels: int
    # Passed to `CLIPImageEncoder` as `embedding_dim`.
    hidden_size: int
    # Must be "gelu" for the conversion to proceed.
    hidden_act: str
    # Passed as `image_size`; also the height/width of the random test input.
    image_size: int
    # Passed as `output_dim`.
    projection_dim: int
    # Passed as `patch_size`.
    patch_size: int
    # Passed as `num_layers`.
    num_hidden_layers: int
    # Passed as `num_attention_heads`.
    num_attention_heads: int
    # Passed as `feedforward_dim`.
    intermediate_size: int
    # Passed as `layer_norm_eps`.
    layer_norm_eps: float
def setup_converter(args: Args) -> ModelConverter:
    """Convert a transformers `CLIPVisionModelWithProjection` into a refiners `CLIPImageEncoder`.

    The source model is loaded from `args.source_path` / `args.subfolder`, a `CLIPImageEncoder` is
    instantiated from the source configuration, the weights are transferred, and the two models are
    compared on a random input.

    Args:
        args: Parsed command-line arguments. Reads `source_path`, `subfolder`, `verbose` and `threshold`.

    Returns:
        The `ModelConverter` whose target model carries the converted weights.

    Raises:
        AssertionError: If the source model or its configuration is unsupported, if no state dict mapping
            could be built, or if `compare_models` fails at `args.threshold`.
    """
    # low_cpu_mem_usage=False stops some annoying console messages telling us to `pip install accelerate`
    source: nn.Module = CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=args.source_path,
        subfolder=args.subfolder,
        low_cpu_mem_usage=False,
    )
    assert isinstance(source, nn.Module), "Source model is not a nn.Module"
    config = cast(CLIPImageEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]

    assert (
        config.architectures[0] == "CLIPVisionModelWithProjection"
    ), f"Unsupported architecture: {config.architectures[0]}"
    assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
    assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"

    target = CLIPImageEncoder(
        image_size=config.image_size,
        embedding_dim=config.hidden_size,
        output_dim=config.projection_dim,
        patch_size=config.patch_size,
        num_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        feedforward_dim=config.intermediate_size,
        layer_norm_eps=config.layer_norm_eps,
    )

    x = torch.randn(1, 3, config.image_size, config.image_size)

    # Honor the `--verbose` command-line flag instead of always forcing verbose output.
    converter = ModelConverter(source_model=source, target_model=target, verbose=args.verbose)

    # Custom conversion logic since the class embedding (fl.Parameter layer) is not supported out-of-the-box by the
    # converter
    mapping = converter.map_state_dicts((x,))
    assert mapping is not None

    source_state_dict = source.state_dict()
    target_state_dict = target.state_dict()

    # Remove the class embedding from state dict since it was not mapped by the model converter
    class_embedding = target.ensure_find(fl.Parameter)
    class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
    assert class_embedding_key is not None
    assert class_embedding_key in target_state_dict
    del target_state_dict[class_embedding_key]

    converted_state_dict = converter._convert_state_dict(  # type: ignore[reportPrivateUsage]
        source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
    )
    # strict=False because the class embedding key was removed above and is set by hand below.
    target.load_state_dict(state_dict=converted_state_dict, strict=False)

    # Ad hoc post-conversion steps
    embed = source.vision_model.embeddings.class_embedding
    class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight))  # type: ignore

    assert converter.compare_models((x,), threshold=args.threshold)

    return converter
def main() -> None:
    """Parse the command-line arguments, run the conversion and save the result as safetensors.

    When `--to` is omitted, the output file name is derived from the final component of the source path
    and the subfolder. With `--half`, every tensor is cast to half precision before saving.
    """
    parser = argparse.ArgumentParser(
        description="Converts a CLIPImageEncoder from the library transformers from the HuggingFace Hub to refiners."
    )
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="stabilityai/stable-diffusion-2-1-unclip",
        help=(
            "Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
            " stabilityai/stable-diffusion-2-1-unclip"
        ),
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        dest="subfolder",
        default="image_encoder",
        help="Subfolder in the source path where the model is located inside the Hub. Default: image_encoder",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        # The help text describes the default actually computed below.
        help=(
            "Output path (.safetensors) for converted model. If not provided, defaults to"
            " '<source name>-<subfolder>.safetensors' in the current working directory."
        ),
    )
    parser.add_argument("--half", action="store_true", help="Convert to half precision.")
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    parser.add_argument("--threshold", type=float, default=1e-2, help="Threshold for model comparison. Default: 1e-2")
    args = parser.parse_args(namespace=Args())

    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"

    converter = setup_converter(args=args)

    # Do not use converter.save_to_safetensors since it is not in a valid state due to the ad hoc conversion
    state_dict = converter.target_model.state_dict()
    if args.half:
        state_dict = {key: value.half() for key, value in state_dict.items()}
    save_to_safetensors(path=args.output_path, tensors=state_dict)
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()