make clip text converter support SDXL

i.e. convert the 2nd text encoder and save the final double text encoder
This commit is contained in:
Cédric Deltheil 2023-09-11 16:02:22 +02:00 committed by Cédric Deltheil
parent be54cfc016
commit cc3b20320d

View file

@ -1,10 +1,13 @@
import argparse
from pathlib import Path
from typing import cast
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from transformers import CLIPTextModelWithProjection # type: ignore
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL, CLIPTextEncoderG
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from refiners.fluxion.utils import save_to_safetensors
import refiners.fluxion.layers as fl
@ -76,6 +79,13 @@ def main() -> None:
" CLIPTextModel)"
),
)
parser.add_argument(
"--subfolder2",
type=str,
dest="subfolder2",
default=None,
help="Additional subfolder for the 2nd text encoder (useful for SDXL). Default: None",
)
parser.add_argument(
"--to",
type=str,
@ -86,7 +96,7 @@ def main() -> None:
" source path."
),
)
parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True")
parser.add_argument("--half", action="store_true", help="Convert to half precision. Default: True")
parser.add_argument(
"--verbose",
action="store_true",
@ -97,6 +107,31 @@ def main() -> None:
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
converter = setup_converter(args=args)
if args.subfolder2 is not None:
# Assume this is the second text encoder of Stable Diffusion XL
args.subfolder = args.subfolder2
converter2 = setup_converter(args=args)
text_encoder_l = CLIPTextEncoderL()
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
projection = cast(CLIPTextEncoder, converter2.target_model)[-1]
assert isinstance(projection, fl.Linear)
text_encoder_g_with_projection = CLIPTextEncoderG()
text_encoder_g_with_projection.append(module=projection)
text_encoder_g_with_projection.load_state_dict(state_dict=converter2.get_state_dict())
projection = text_encoder_g_with_projection.pop(index=-1)
assert isinstance(projection, fl.Linear)
double_text_encoder = DoubleTextEncoder(
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=projection
)
state_dict = double_text_encoder.state_dict()
if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}
save_to_safetensors(path=args.output_path, tensors=state_dict)
else:
converter.save_to_safetensors(path=args.output_path, half=args.half)