mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
make clip text converter support SDXL
i.e. convert the 2nd text encoder and save the final double text encoder
This commit is contained in:
parent
be54cfc016
commit
cc3b20320d
|
@ -1,10 +1,13 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from torch import nn
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from transformers import CLIPTextModelWithProjection # type: ignore
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL, CLIPTextEncoderG
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
|
@ -76,6 +79,13 @@ def main() -> None:
|
|||
" CLIPTextModel)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subfolder2",
|
||||
type=str,
|
||||
dest="subfolder2",
|
||||
default=None,
|
||||
help="Additional subfolder for the 2nd text encoder (useful for SDXL). Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
|
@ -86,7 +96,7 @@ def main() -> None:
|
|||
" source path."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True")
|
||||
parser.add_argument("--half", action="store_true", help="Convert to half precision. Default: True")
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
|
@ -97,7 +107,32 @@ def main() -> None:
|
|||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
||||
converter = setup_converter(args=args)
|
||||
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||
if args.subfolder2 is not None:
|
||||
# Assume this is the second text encoder of Stable Diffusion XL
|
||||
args.subfolder = args.subfolder2
|
||||
converter2 = setup_converter(args=args)
|
||||
|
||||
text_encoder_l = CLIPTextEncoderL()
|
||||
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
|
||||
|
||||
projection = cast(CLIPTextEncoder, converter2.target_model)[-1]
|
||||
assert isinstance(projection, fl.Linear)
|
||||
text_encoder_g_with_projection = CLIPTextEncoderG()
|
||||
text_encoder_g_with_projection.append(module=projection)
|
||||
text_encoder_g_with_projection.load_state_dict(state_dict=converter2.get_state_dict())
|
||||
|
||||
projection = text_encoder_g_with_projection.pop(index=-1)
|
||||
assert isinstance(projection, fl.Linear)
|
||||
double_text_encoder = DoubleTextEncoder(
|
||||
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=projection
|
||||
)
|
||||
|
||||
state_dict = double_text_encoder.state_dict()
|
||||
if args.half:
|
||||
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||
else:
|
||||
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in a new issue