add SDXL-Lightning weights to conversion script + support safetensors

This commit is contained in:
Pierre Chapuis 2024-02-23 17:34:50 +01:00
parent 7f51d18045
commit 7d8e3fc1db
3 changed files with 59 additions and 2 deletions

View file

@ -45,7 +45,7 @@ First, install test dependencies with:
rye sync --all-features rye sync --all-features
``` ```
Then, download and convert all the necessary weights. Be aware that this will use around 80 GB of disk space: Then, download and convert all the necessary weights. Be aware that this will use around 100 GB of disk space:
```bash ```bash
python scripts/prepare_test_weights.py python scripts/prepare_test_weights.py

View file

@ -7,6 +7,7 @@ from diffusers import UNet2DConditionModel # type: ignore
from torch import nn from torch import nn
from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_from_safetensors, load_tensors
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
@ -28,7 +29,12 @@ def setup_converter(args: Args) -> ModelConverter:
low_cpu_mem_usage=False, low_cpu_mem_usage=False,
) )
if args.override_weights is not None: if args.override_weights is not None:
sd = torch.load(args.override_weights) # type: ignore if args.override_weights.endswith(".pth"):
sd = load_tensors(args.override_weights)
elif args.override_weights.endswith(".safetensors"):
sd = load_from_safetensors(args.override_weights)
else:
raise ValueError(f"Unsupported file format: {args.override_weights}")
source.load_state_dict(sd) source.load_state_dict(sd)
source_in_channels: int = source.config.in_channels # type: ignore source_in_channels: int = source.config.in_channels # type: ignore
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore

View file

@ -393,6 +393,28 @@ def download_lcm_lora():
) )
def download_sdxl_lightning_base():
base_folder = os.path.join(test_weights_dir, "ByteDance/SDXL-Lightning")
download_file(
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_unet.safetensors",
base_folder,
expected_hash="1b76cca3",
)
download_file(
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_1step_unet_x0.safetensors",
base_folder,
expected_hash="38e605bd",
)
def download_sdxl_lightning_lora():
download_file(
"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors",
dest_folder=test_weights_dir,
expected_hash="9783edac",
)
def printg(msg: str): def printg(msg: str):
"""print in green color""" """print in green color"""
print("\033[92m" + msg + "\033[0m") print("\033[92m" + msg + "\033[0m")
@ -692,6 +714,32 @@ def convert_lcm_base():
) )
def convert_sdxl_lightning_base():
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
"tests/weights/sdxl_lightning_4step_unet.safetensors",
additional_args=[
"--override-weights",
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors",
],
half=True,
expected_hash="cfdc46da",
)
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
"tests/weights/sdxl_lightning_1step_unet_x0.safetensors",
additional_args=[
"--override-weights",
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_1step_unet_x0.safetensors",
],
half=True,
expected_hash="21166a64",
)
def download_all(): def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n") print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5") download_sd15("runwayml/stable-diffusion-v1-5")
@ -710,6 +758,8 @@ def download_all():
download_control_lora_fooocus() download_control_lora_fooocus()
download_lcm_base() download_lcm_base()
download_lcm_lora() download_lcm_lora()
download_sdxl_lightning_base()
download_sdxl_lightning_lora()
def convert_all(): def convert_all():
@ -727,6 +777,7 @@ def convert_all():
convert_dinov2() convert_dinov2()
convert_control_lora_fooocus() convert_control_lora_fooocus()
convert_lcm_base() convert_lcm_base()
convert_sdxl_lightning_base()
def main(): def main():