From 7d8e3fc1db4c38f8a8947f6119d478233f983c42 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 23 Feb 2024 17:34:50 +0100 Subject: [PATCH] add SDXL-Lightning weights to conversion script + support safetensors --- CONTRIBUTING.md | 2 +- scripts/conversion/convert_diffusers_unet.py | 8 ++- scripts/prepare_test_weights.py | 51 ++++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 200e10d..ffcb7b8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -45,7 +45,7 @@ First, install test dependencies with: rye sync --all-features ``` -Then, download and convert all the necessary weights. Be aware that this will use around 80 GB of disk space: +Then, download and convert all the necessary weights. Be aware that this will use around 100 GB of disk space: ```bash python scripts/prepare_test_weights.py diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py index 014db37..610ec88 100644 --- a/scripts/conversion/convert_diffusers_unet.py +++ b/scripts/conversion/convert_diffusers_unet.py @@ -7,6 +7,7 @@ from diffusers import UNet2DConditionModel # type: ignore from torch import nn from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.utils import load_from_safetensors, load_tensors from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter @@ -28,7 +29,12 @@ def setup_converter(args: Args) -> ModelConverter: low_cpu_mem_usage=False, ) if args.override_weights is not None: - sd = torch.load(args.override_weights) # type: ignore + if args.override_weights.endswith(".pth"): + sd = load_tensors(args.override_weights) + elif args.override_weights.endswith(".safetensors"): + sd = load_from_safetensors(args.override_weights) + else: + raise ValueError(f"Unsupported file format: {args.override_weights}") source.load_state_dict(sd) source_in_channels: int = source.config.in_channels # type: ignore source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index 6d41982..0d50ae3 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -393,6 +393,28 @@ def download_lcm_lora(): ) +def download_sdxl_lightning_base(): + base_folder = os.path.join(test_weights_dir, "ByteDance/SDXL-Lightning") + download_file( + f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_unet.safetensors", + base_folder, + expected_hash="1b76cca3", + ) + download_file( + f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_1step_unet_x0.safetensors", + base_folder, + expected_hash="38e605bd", + ) + + +def download_sdxl_lightning_lora(): + download_file( + "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors", + dest_folder=test_weights_dir, + expected_hash="9783edac", + ) + + def printg(msg: str): """print in green color""" print("\033[92m" + msg + "\033[0m") @@ -692,6 +714,32 @@ def convert_lcm_base(): ) +def convert_sdxl_lightning_base(): + run_conversion_script( + "convert_diffusers_unet.py", + "tests/weights/stabilityai/stable-diffusion-xl-base-1.0", + "tests/weights/sdxl_lightning_4step_unet.safetensors", + additional_args=[ + "--override-weights", + "tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors", + ], + half=True, + expected_hash="cfdc46da", + ) + + run_conversion_script( + "convert_diffusers_unet.py", + "tests/weights/stabilityai/stable-diffusion-xl-base-1.0", + "tests/weights/sdxl_lightning_1step_unet_x0.safetensors", + additional_args=[ + "--override-weights", + "tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_1step_unet_x0.safetensors", + ], + half=True, + expected_hash="21166a64", + ) + + def download_all(): print(f"\nAll weights will be downloaded to {test_weights_dir}\n") download_sd15("runwayml/stable-diffusion-v1-5") @@ -710,6 +758,8 @@ def download_all(): download_control_lora_fooocus() download_lcm_base() download_lcm_lora() + download_sdxl_lightning_base() + download_sdxl_lightning_lora() def convert_all(): @@ -727,6 +777,7 @@ def convert_all(): convert_dinov2() convert_control_lora_fooocus() convert_lcm_base() + convert_sdxl_lightning_base() def main():