mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
make clip text converter support SDXL
i.e. convert the 2nd text encoder and save the final double text encoder
This commit is contained in:
parent
be54cfc016
commit
cc3b20320d
|
@ -1,10 +1,13 @@
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
from transformers import CLIPTextModelWithProjection # type: ignore
|
from transformers import CLIPTextModelWithProjection # type: ignore
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL, CLIPTextEncoderG
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,6 +79,13 @@ def main() -> None:
|
||||||
" CLIPTextModel)"
|
" CLIPTextModel)"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--subfolder2",
|
||||||
|
type=str,
|
||||||
|
dest="subfolder2",
|
||||||
|
default=None,
|
||||||
|
help="Additional subfolder for the 2nd text encoder (useful for SDXL). Default: None",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--to",
|
"--to",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -86,7 +96,7 @@ def main() -> None:
|
||||||
" source path."
|
" source path."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--half", action="store_true", default=True, help="Convert to half precision. Default: True")
|
parser.add_argument("--half", action="store_true", help="Convert to half precision. Default: True")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose",
|
"--verbose",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
@ -97,7 +107,32 @@ def main() -> None:
|
||||||
if args.output_path is None:
|
if args.output_path is None:
|
||||||
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
||||||
converter = setup_converter(args=args)
|
converter = setup_converter(args=args)
|
||||||
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
if args.subfolder2 is not None:
|
||||||
|
# Assume this is the second text encoder of Stable Diffusion XL
|
||||||
|
args.subfolder = args.subfolder2
|
||||||
|
converter2 = setup_converter(args=args)
|
||||||
|
|
||||||
|
text_encoder_l = CLIPTextEncoderL()
|
||||||
|
text_encoder_l.load_state_dict(state_dict=converter.get_state_dict())
|
||||||
|
|
||||||
|
projection = cast(CLIPTextEncoder, converter2.target_model)[-1]
|
||||||
|
assert isinstance(projection, fl.Linear)
|
||||||
|
text_encoder_g_with_projection = CLIPTextEncoderG()
|
||||||
|
text_encoder_g_with_projection.append(module=projection)
|
||||||
|
text_encoder_g_with_projection.load_state_dict(state_dict=converter2.get_state_dict())
|
||||||
|
|
||||||
|
projection = text_encoder_g_with_projection.pop(index=-1)
|
||||||
|
assert isinstance(projection, fl.Linear)
|
||||||
|
double_text_encoder = DoubleTextEncoder(
|
||||||
|
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=projection
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict = double_text_encoder.state_dict()
|
||||||
|
if args.half:
|
||||||
|
state_dict = {key: value.half() for key, value in state_dict.items()}
|
||||||
|
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||||
|
else:
|
||||||
|
converter.save_to_safetensors(path=args.output_path, half=args.half)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in a new issue