mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
pick the right class for CLIP text converter
i.e. CLIPTextModel by default, or CLIPTextModelWithProjection for SDXL's so-called text_encoder_2. This silences false-positive warnings like: Some weights of CLIPTextModelWithProjection were not initialized from the model checkpoint [...]
This commit is contained in:
parent
a6a9c8b972
commit
dd87b9706e
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import CLIPTextModelWithProjection # type: ignore
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection # type: ignore
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
|
@ -21,9 +21,10 @@ class Args(argparse.Namespace):
|
||||||
verbose: bool
|
verbose: bool
|
||||||
|
|
||||||
|
|
||||||
def setup_converter(args: Args) -> ModelConverter:
|
def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
|
||||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||||
source: nn.Module = CLIPTextModelWithProjection.from_pretrained( # type: ignore
|
cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
|
||||||
|
source: nn.Module = cls.from_pretrained( # type: ignore
|
||||||
pretrained_model_name_or_path=args.source_path,
|
pretrained_model_name_or_path=args.source_path,
|
||||||
subfolder=args.subfolder,
|
subfolder=args.subfolder,
|
||||||
low_cpu_mem_usage=False,
|
low_cpu_mem_usage=False,
|
||||||
|
@ -36,6 +37,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
||||||
num_attention_heads: int = source.config.num_attention_heads # type: ignore
|
num_attention_heads: int = source.config.num_attention_heads # type: ignore
|
||||||
feed_forward_dim: int = source.config.intermediate_size # type: ignore
|
feed_forward_dim: int = source.config.intermediate_size # type: ignore
|
||||||
use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore
|
use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore
|
||||||
|
assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
|
||||||
target = CLIPTextEncoder(
|
target = CLIPTextEncoder(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
num_layers=num_layers,
|
num_layers=num_layers,
|
||||||
|
@ -43,13 +45,8 @@ def setup_converter(args: Args) -> ModelConverter:
|
||||||
feedforward_dim=feed_forward_dim,
|
feedforward_dim=feed_forward_dim,
|
||||||
use_quick_gelu=use_quick_gelu,
|
use_quick_gelu=use_quick_gelu,
|
||||||
)
|
)
|
||||||
match architecture:
|
if architecture == "CLIPTextModelWithProjection":
|
||||||
case "CLIPTextModel":
|
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
|
||||||
source.text_projection = fl.Identity()
|
|
||||||
case "CLIPTextModelWithProjection":
|
|
||||||
target.append(module=fl.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False))
|
|
||||||
case _:
|
|
||||||
raise RuntimeError(f"Unsupported architecture: {architecture}")
|
|
||||||
text = "What a nice cat you have there!"
|
text = "What a nice cat you have there!"
|
||||||
tokenizer = target.ensure_find(CLIPTokenizer)
|
tokenizer = target.ensure_find(CLIPTokenizer)
|
||||||
tokens = tokenizer(text)
|
tokens = tokenizer(text)
|
||||||
|
@ -114,7 +111,7 @@ def main() -> None:
|
||||||
if args.subfolder2 is not None:
|
if args.subfolder2 is not None:
|
||||||
# Assume this is the second text encoder of Stable Diffusion XL
|
# Assume this is the second text encoder of Stable Diffusion XL
|
||||||
args.subfolder = args.subfolder2
|
args.subfolder = args.subfolder2
|
||||||
converter2 = setup_converter(args=args)
|
converter2 = setup_converter(args=args, with_projection=True)
|
||||||
|
|
||||||
text_encoder_l = CLIPTextEncoderL()
|
text_encoder_l = CLIPTextEncoderL()
|
||||||
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
|
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
|
||||||
|
|
Loading…
Reference in a new issue