pick the right class for CLIP text converter

i.e. CLIPTextModel by default, or CLIPTextModelWithProjection for SDXL's
so-called text_encoder_2

This silences false-positive warnings like:

    Some weights of CLIPTextModelWithProjection were not initialized
    from the model checkpoint [...]
This commit is contained in:
Cédric Deltheil 2024-01-17 17:15:46 +01:00 committed by Cédric Deltheil
parent a6a9c8b972
commit dd87b9706e

View file

@ -3,7 +3,7 @@ from pathlib import Path
from typing import cast
from torch import nn
from transformers import CLIPTextModelWithProjection # type: ignore
from transformers import CLIPTextModel, CLIPTextModelWithProjection # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
@ -21,9 +21,10 @@ class Args(argparse.Namespace):
verbose: bool
def setup_converter(args: Args) -> ModelConverter:
def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
# low_cpu_mem_usage=False stops some annoying console messages telling us to `pip install accelerate`
source: nn.Module = CLIPTextModelWithProjection.from_pretrained( # type: ignore
cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
source: nn.Module = cls.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path,
subfolder=args.subfolder,
low_cpu_mem_usage=False,
@ -36,6 +37,7 @@ def setup_converter(args: Args) -> ModelConverter:
num_attention_heads: int = source.config.num_attention_heads # type: ignore
feed_forward_dim: int = source.config.intermediate_size # type: ignore
use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore
assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
target = CLIPTextEncoder(
embedding_dim=embedding_dim,
num_layers=num_layers,
@ -43,13 +45,8 @@ def setup_converter(args: Args) -> ModelConverter:
feedforward_dim=feed_forward_dim,
use_quick_gelu=use_quick_gelu,
)
match architecture:
case "CLIPTextModel":
source.text_projection = fl.Identity()
case "CLIPTextModelWithProjection":
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
case _:
raise RuntimeError(f"Unsupported architecture: {architecture}")
if architecture == "CLIPTextModelWithProjection":
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
text = "What a nice cat you have there!"
tokenizer = target.ensure_find(CLIPTokenizer)
tokens = tokenizer(text)
@ -114,7 +111,7 @@ def main() -> None:
if args.subfolder2 is not None:
# Assume this is the second text encoder of Stable Diffusion XL
args.subfolder = args.subfolder2
converter2 = setup_converter(args=args)
converter2 = setup_converter(args=args, with_projection=True)
text_encoder_l = CLIPTextEncoderL()
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())