2023-08-30 08:05:31 +00:00
|
|
|
import argparse
|
|
|
|
from pathlib import Path
|
2024-08-02 09:56:37 +00:00
|
|
|
from typing import NamedTuple, cast
|
2023-12-11 10:46:38 +00:00
|
|
|
|
|
|
|
import torch
|
2023-08-30 08:05:31 +00:00
|
|
|
from torch import nn
|
|
|
|
from transformers import CLIPVisionModelWithProjection # type: ignore
|
2023-12-11 10:46:38 +00:00
|
|
|
|
2023-08-30 08:05:31 +00:00
|
|
|
import refiners.fluxion.layers as fl
|
2023-12-11 10:46:38 +00:00
|
|
|
from refiners.fluxion.model_converter import ModelConverter
|
|
|
|
from refiners.fluxion.utils import save_to_safetensors
|
|
|
|
from refiners.foundationals.clip.image_encoder import CLIPImageEncoder
|
2023-08-30 08:05:31 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Args(argparse.Namespace):
|
|
|
|
source_path: str
|
|
|
|
subfolder: str
|
|
|
|
output_path: str | None
|
|
|
|
half: bool
|
|
|
|
verbose: bool
|
2023-09-07 12:46:15 +00:00
|
|
|
threshold: float
|
2023-08-30 08:05:31 +00:00
|
|
|
|
|
|
|
|
2024-08-02 09:56:37 +00:00
|
|
|
class CLIPImageEncoderConfig(NamedTuple):
    """Typed view of the CLIP vision configuration attributes read by this script.

    In this script it serves only as a `typing.cast` target for the transformers model's `config`
    object, so that attribute accesses on it are statically typed.
    """

    # Model class names; the conversion requires the first entry to be "CLIPVisionModelWithProjection".
    architectures: list[str]
    # Number of input image channels; the conversion requires exactly 3.
    num_channels: int
    # Transformer width, forwarded as `embedding_dim` of the refiners encoder.
    hidden_size: int
    # Activation function name; only "gelu" is accepted by the conversion.
    hidden_act: str
    # Input resolution, forwarded as `image_size` and used for the square dummy input.
    image_size: int
    # Size of the projected output, forwarded as `output_dim`.
    projection_dim: int
    # Patch size of the patch embedding, forwarded as `patch_size`.
    patch_size: int
    # Number of transformer layers, forwarded as `num_layers`.
    num_hidden_layers: int
    # Number of attention heads per layer, forwarded as `num_attention_heads`.
    num_attention_heads: int
    # Hidden size of the feed-forward sublayer, forwarded as `feedforward_dim`.
    intermediate_size: int
    # Epsilon of the layer norms, forwarded as `layer_norm_eps`.
    layer_norm_eps: float
|
|
|
|
|
|
|
|
|
2023-08-30 08:05:31 +00:00
|
|
|
def setup_converter(args: Args) -> ModelConverter:
    """Convert a transformers `CLIPVisionModelWithProjection` into a refiners `CLIPImageEncoder`.

    The source model is loaded from `args.source_path` / `args.subfolder`. Its config is validated, and a
    matching `CLIPImageEncoder` is built. Weights are mapped with a `ModelConverter`, except for the class
    embedding, which the converter cannot map and which is copied by hand. Finally, the outputs of both
    models are compared on a random input.

    Args:
        args: Parsed command-line options. Reads `source_path`, `subfolder`, `verbose` and `threshold`.

    Returns:
        The converter whose `target_model` holds the converted weights. Because of the ad hoc class-embedding
        copy, the converter itself should not be used to save the weights
        (`converter.save_to_safetensors`). Save `converter.target_model.state_dict()` instead.

    Raises:
        AssertionError: If the loaded model or its config is unsupported, if the class embedding cannot be
            located in the target, or if `converter.compare_models` reports a mismatch at `args.threshold`.
    """
    # low_cpu_mem_usage=False stops some annoying console messages asking us to `pip install accelerate`
    source: nn.Module = CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=args.source_path,
        subfolder=args.subfolder,
        low_cpu_mem_usage=False,
    )
    assert isinstance(source, nn.Module), "Source model is not a nn.Module"
    # The cast only gives static types to the attributes read from the transformers config object.
    config = cast(CLIPImageEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]

    assert (
        config.architectures[0] == "CLIPVisionModelWithProjection"
    ), f"Unsupported architecture: {config.architectures[0]}"
    assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
    assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"

    target = CLIPImageEncoder(
        image_size=config.image_size,
        embedding_dim=config.hidden_size,
        output_dim=config.projection_dim,
        patch_size=config.patch_size,
        num_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        feedforward_dim=config.intermediate_size,
        layer_norm_eps=config.layer_norm_eps,
    )

    # Random dummy input (batch 1, 3 channels, square `image_size`), used both to build the state-dict
    # mapping and for the final output comparison.
    x = torch.randn(1, 3, config.image_size, config.image_size)

    # Honour the --verbose flag instead of always enabling the converter's verbose output.
    converter = ModelConverter(source_model=source, target_model=target, verbose=args.verbose)

    # Custom conversion logic since the class embedding (fl.Parameter layer) is not supported out-of-the-box by the
    # converter
    mapping = converter.map_state_dicts((x,))
    assert mapping is not None

    source_state_dict = source.state_dict()
    target_state_dict = target.state_dict()

    # Remove the class embedding from state dict since it was not mapped by the model converter
    class_embedding = target.ensure_find(fl.Parameter)
    class_embedding_key = next(
        (name for name, param in target.named_parameters() if param is class_embedding.weight),
        None,
    )
    assert class_embedding_key is not None
    assert class_embedding_key in target_state_dict
    del target_state_dict[class_embedding_key]

    converted_state_dict = converter._convert_state_dict(  # type: ignore[reportPrivateUsage]
        source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
    )
    # strict=False: the class embedding key was removed above and is filled in by hand below.
    target.load_state_dict(state_dict=converted_state_dict, strict=False)

    # Ad hoc post-conversion step: copy the source class embedding into the target's fl.Parameter layer.
    embed = source.vision_model.embeddings.class_embedding
    class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight))  # type: ignore

    # Evaluate the comparison outside the assert, so it still runs under `python -O` (which strips asserts).
    models_match = converter.compare_models((x,), threshold=args.threshold)
    assert models_match, f"Converted model does not match the source model (threshold={args.threshold})"

    return converter
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Parse the command line, convert the CLIP image encoder and save its weights as .safetensors.

    If `--to` is omitted, the output file is `<stem of --from>-<--subfolder>.safetensors`, a path relative to
    the current working directory. With `--half`, every tensor is converted with `.half()` before saving.
    """
    parser = argparse.ArgumentParser(
        description="Converts a CLIPImageEncoder from the library transformers from the HuggingFace Hub to refiners."
    )
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="stabilityai/stable-diffusion-2-1-unclip",
        help=(
            "Can be a path to a .bin file, a .safetensors file or a model name from the HuggingFace Hub. Default:"
            " stabilityai/stable-diffusion-2-1-unclip"
        ),
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        dest="subfolder",
        default="image_encoder",
        help="Subfolder in the source path where the model is located inside the Hub. Default: image_encoder",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        # The previous help claimed the default was "the same as the source path", which does not match the
        # default computed below; describe the actual derived file name instead.
        help=(
            "Output path (.safetensors) for converted model. If not provided, the output path will be"
            " <source path stem>-<subfolder>.safetensors in the current directory"
            " (e.g. stable-diffusion-2-1-unclip-image_encoder.safetensors)."
        ),
    )
    parser.add_argument("--half", action="store_true", help="Convert to half precision.")
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    parser.add_argument("--threshold", type=float, default=1e-2, help="Threshold for model comparison. Default: 1e-2")
    args = parser.parse_args(namespace=Args())

    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"

    converter = setup_converter(args=args)

    # Do not use converter.save_to_safetensors since it is not in a valid state due to the ad hoc conversion
    state_dict = converter.target_model.state_dict()
    if args.half:
        state_dict = {key: value.half() for key, value in state_dict.items()}
    save_to_safetensors(path=args.output_path, tensors=state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|