mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
|
import torch
|
||
|
|
||
|
from safetensors.torch import save_file
|
||
|
from refiners.fluxion.utils import (
|
||
|
create_state_dict_mapping,
|
||
|
convert_state_dict,
|
||
|
)
|
||
|
|
||
|
from diffusers import DiffusionPipeline
|
||
|
from transformers.models.clip.modeling_clip import CLIPTextModel
|
||
|
|
||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
    """Convert the weights of ``src_model`` into a ``CLIPTextEncoderL`` state dict.

    A fresh ``CLIPTextEncoderL`` is instantiated as the conversion target. Its
    tokenizer output for a short prompt is handed to
    ``create_state_dict_mapping`` as the example input, and the resulting
    mapping is used by ``convert_state_dict`` to build the converted state dict.

    Args:
        src_model: The source text encoder whose weights are converted.

    Returns:
        The converted state dict, with ``.half()`` applied to every tensor.
    """
    target = CLIPTextEncoderL()

    # Example input for the mapping helper: a tokenized prompt of length 77.
    sample_tokens = target.tokenizer("Nice cat", sequence_length=77)
    key_mapping = create_state_dict_mapping(src_model, target, [sample_tokens])

    converted = convert_state_dict(src_model.state_dict(), target.state_dict(), key_mapping)

    # Cast each tensor with .half() before returning the result.
    half_tensors: dict[str, torch.Tensor] = {}
    for name, tensor in converted.items():
        half_tensors[name] = tensor.half()
    return half_tensors
|
||
|
|
||
|
|
||
|
def main():
    """Command-line entry point: load a pipeline, convert its text encoder, save it.

    Options:
        --from: Source model passed to ``DiffusionPipeline.from_pretrained``
            (default ``runwayml/stable-diffusion-v1-5``).
        --output-file: Path the converted tensors are written to with
            ``save_file`` (default ``CLIPTextEncoderL.safetensors``).
    """
    import argparse

    cli = argparse.ArgumentParser()
    # "from" is a Python keyword, so the value is stored under "source" instead.
    cli.add_argument(
        "--from",
        dest="source",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        required=False,
        help="Source model",
    )
    cli.add_argument(
        "--output-file",
        type=str,
        default="CLIPTextEncoderL.safetensors",
        required=False,
        help="Path for the output file",
    )
    options = cli.parse_args()

    # Only the text encoder of the loaded pipeline is converted.
    pipeline = DiffusionPipeline.from_pretrained(options.source)
    converted = convert(pipeline.text_encoder)
    save_file(converted, options.output_file)
|
||
|
|
||
|
|
||
|
# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|