mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
pick the right class for CLIP text converter
i.e. CLIPTextModel by default or CLIPTextModelWithProjection for SDXL so-called text_encoder_2 This silent false positive warnings like: Some weights of CLIPTextModelWithProjection were not initialized from the model checkpoint [...]
This commit is contained in:
parent
a6a9c8b972
commit
dd87b9706e
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
from typing import cast
|
||||
|
||||
from torch import nn
|
||||
from transformers import CLIPTextModelWithProjection # type: ignore
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection # type: ignore
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
|
@ -21,9 +21,10 @@ class Args(argparse.Namespace):
|
|||
verbose: bool
|
||||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = CLIPTextModelWithProjection.from_pretrained( # type: ignore
|
||||
cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
|
||||
source: nn.Module = cls.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
subfolder=args.subfolder,
|
||||
low_cpu_mem_usage=False,
|
||||
|
@ -36,6 +37,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
num_attention_heads: int = source.config.num_attention_heads # type: ignore
|
||||
feed_forward_dim: int = source.config.intermediate_size # type: ignore
|
||||
use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore
|
||||
assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
|
||||
target = CLIPTextEncoder(
|
||||
embedding_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
|
@ -43,13 +45,8 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
feedforward_dim=feed_forward_dim,
|
||||
use_quick_gelu=use_quick_gelu,
|
||||
)
|
||||
match architecture:
|
||||
case "CLIPTextModel":
|
||||
source.text_projection = fl.Identity()
|
||||
case "CLIPTextModelWithProjection":
|
||||
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
|
||||
case _:
|
||||
raise RuntimeError(f"Unsupported architecture: {architecture}")
|
||||
if architecture == "CLIPTextModelWithProjection":
|
||||
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
|
||||
text = "What a nice cat you have there!"
|
||||
tokenizer = target.ensure_find(CLIPTokenizer)
|
||||
tokens = tokenizer(text)
|
||||
|
@ -114,7 +111,7 @@ def main() -> None:
|
|||
if args.subfolder2 is not None:
|
||||
# Assume this is the second text encoder of Stable Diffusion XL
|
||||
args.subfolder = args.subfolder2
|
||||
converter2 = setup_converter(args=args)
|
||||
converter2 = setup_converter(args=args, with_projection=True)
|
||||
|
||||
text_encoder_l = CLIPTextEncoderL()
|
||||
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
|
||||
|
|
Loading…
Reference in a new issue