2023-08-24 00:26:37 +00:00
|
|
|
import argparse
|
|
|
|
from pathlib import Path
|
2023-12-11 10:46:38 +00:00
|
|
|
|
2023-08-24 00:26:37 +00:00
|
|
|
import torch
|
|
|
|
from diffusers import AutoencoderKL # type: ignore
|
2023-12-11 10:46:38 +00:00
|
|
|
from torch import nn
|
|
|
|
|
2023-08-24 00:26:37 +00:00
|
|
|
from refiners.fluxion.model_converter import ModelConverter
|
2023-12-11 10:46:38 +00:00
|
|
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
2023-08-24 00:26:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Args(argparse.Namespace):
|
|
|
|
source_path: str
|
|
|
|
output_path: str | None
|
|
|
|
use_half: bool
|
|
|
|
verbose: bool
|
|
|
|
|
|
|
|
|
|
|
|
def setup_converter(args: Args) -> ModelConverter:
    """Create and run the diffusers-to-refiners autoencoder converter.

    Args:
        args: Parsed command-line arguments. Only `source_path`, `subfolder`
            and `verbose` are read here.

    Returns:
        The `ModelConverter` instance, after its `run` call reported success.

    Raises:
        RuntimeError: If `ModelConverter.run` returns a falsy value.
    """
    refiners_autoencoder = LatentDiffusionAutoencoder()
    # low_cpu_mem_usage=False silences the console messages asking us to `pip install accelerate`
    diffusers_autoencoder: nn.Module = AutoencoderKL.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=args.source_path,
        subfolder=args.subfolder,
        low_cpu_mem_usage=False,
    )  # type: ignore

    # Random tensor of shape (1, 3, 512, 512) fed to the source model by the
    # converter (presumably a single 3-channel 512x512 image -- confirm).
    sample_input = torch.randn(1, 3, 512, 512)

    converter = ModelConverter(
        source_model=diffusers_autoencoder,
        target_model=refiners_autoencoder,
        skip_output_check=True,
        verbose=args.verbose,
    )
    conversion_succeeded = converter.run(source_args=(sample_input,))
    if not conversion_succeeded:
        raise RuntimeError("Model conversion failed")
    return converter
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, derive a default output
    # path when none is given, run the conversion and save the result.
    parser = argparse.ArgumentParser(
        description="Convert a pretrained diffusers AutoencoderKL model to a refiners Latent Diffusion Autoencoder"
    )
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="runwayml/stable-diffusion-v1-5",
        help="Path to the source pretrained model (default: 'runwayml/stable-diffusion-v1-5').",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        dest="subfolder",
        default="vae",
        help="Subfolder in the source path where the model is located inside the Hub (default: 'vae')",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        # The help text describes the default actually computed below
        # ("<stem>-autoencoder.safetensors"), not a plain extension swap.
        help=(
            "Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
            " be the stem of the source path followed by '-autoencoder.safetensors'."
        ),
    )
    parser.add_argument(
        "--half",
        action="store_true",
        dest="use_half",
        default=False,
        help="Use this flag to save the output file as half precision (default: full precision).",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        dest="verbose",
        default=False,
        help="Use this flag to print verbose output during conversion.",
    )
    args = parser.parse_args(namespace=Args())

    # Default output name: last path component of the source, without its
    # suffix, followed by "-autoencoder.safetensors" (relative path).
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}-autoencoder.safetensors"
    # Narrows `str | None` to `str` for type checkers; always true at this point.
    assert args.output_path is not None

    converter = setup_converter(args=args)
    converter.save_to_safetensors(path=args.output_path, half=args.use_half)
|