From 2106c237d99e410307646664f438abb3ac9e40ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Sun, 24 Sep 2023 16:21:35 +0200 Subject: [PATCH] add T2I-Adapter conversion script --- .../convert_diffusers_t2i_adapter.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 scripts/conversion/convert_diffusers_t2i_adapter.py diff --git a/scripts/conversion/convert_diffusers_t2i_adapter.py b/scripts/conversion/convert_diffusers_t2i_adapter.py new file mode 100644 index 0000000..814fd22 --- /dev/null +++ b/scripts/conversion/convert_diffusers_t2i_adapter.py @@ -0,0 +1,57 @@ +import argparse +from pathlib import Path +import torch +from torch import nn +from diffusers import T2IAdapter # type: ignore +from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, ConditionEncoderXL +from refiners.fluxion.model_converter import ModelConverter + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert a pretrained diffusers T2I-Adapter model to refiners") + parser.add_argument( + "--from", + type=str, + dest="source_path", + required=True, + help="Path or repository name of the source model. (e.g.: 'ip-adapter_sd15.bin').", + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Path to save the converted model (extension will be .safetensors). If not specified, the output path will" + " be the source path with the extension changed to .safetensors." + ), + ) + parser.add_argument( + "--half", + action="store_true", + dest="use_half", + default=False, + help="Use this flag to save the output file as half precision (default: full precision).", + ) + parser.add_argument( + "--verbose", + action="store_true", + dest="verbose", + default=False, + help="Use this flag to print verbose output during conversion.", + ) + args = parser.parse_args() + if args.output_path is None: + args.output_path = f"{Path(args.source_path).name}.safetensors" + assert args.output_path is not None + + sdxl = "xl" in args.source_path + target = ConditionEncoderXL() if sdxl else ConditionEncoder() + source: nn.Module = T2IAdapter.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore + assert isinstance(source, nn.Module), "Source model is not a nn.Module" + + x = torch.randn(1, 3, 1024, 1024) if sdxl else torch.randn(1, 3, 512, 512) + converter = ModelConverter(source_model=source, target_model=target, verbose=args.verbose) + if not converter.run(source_args=(x,)): + raise RuntimeError("Model conversion failed") + + converter.save_to_safetensors(path=args.output_path, half=args.use_half)