mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add SDXL-Lightning weights to conversion script + support safetensors
This commit is contained in:
parent
7f51d18045
commit
7d8e3fc1db
|
@ -45,7 +45,7 @@ First, install test dependencies with:
|
|||
rye sync --all-features
|
||||
```
|
||||
|
||||
Then, download and convert all the necessary weights. Be aware that this will use around 80 GB of disk space:
|
||||
Then, download and convert all the necessary weights. Be aware that this will use around 100 GB of disk space:
|
||||
|
||||
```bash
|
||||
python scripts/prepare_test_weights.py
|
||||
|
|
|
@ -7,6 +7,7 @@ from diffusers import UNet2DConditionModel # type: ignore
|
|||
from torch import nn
|
||||
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import load_from_safetensors, load_tensors
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||
|
||||
|
@ -28,7 +29,12 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
low_cpu_mem_usage=False,
|
||||
)
|
||||
if args.override_weights is not None:
|
||||
sd = torch.load(args.override_weights) # type: ignore
|
||||
if args.override_weights.endswith(".pth"):
|
||||
sd = load_tensors(args.override_weights)
|
||||
elif args.override_weights.endswith(".safetensors"):
|
||||
sd = load_from_safetensors(args.override_weights)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {args.override_weights}")
|
||||
source.load_state_dict(sd)
|
||||
source_in_channels: int = source.config.in_channels # type: ignore
|
||||
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||
|
|
|
@ -393,6 +393,28 @@ def download_lcm_lora():
|
|||
)
|
||||
|
||||
|
||||
def download_sdxl_lightning_base():
|
||||
base_folder = os.path.join(test_weights_dir, "ByteDance/SDXL-Lightning")
|
||||
download_file(
|
||||
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_unet.safetensors",
|
||||
base_folder,
|
||||
expected_hash="1b76cca3",
|
||||
)
|
||||
download_file(
|
||||
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_1step_unet_x0.safetensors",
|
||||
base_folder,
|
||||
expected_hash="38e605bd",
|
||||
)
|
||||
|
||||
|
||||
def download_sdxl_lightning_lora():
|
||||
download_file(
|
||||
"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors",
|
||||
dest_folder=test_weights_dir,
|
||||
expected_hash="9783edac",
|
||||
)
|
||||
|
||||
|
||||
def printg(msg: str):
|
||||
"""print in green color"""
|
||||
print("\033[92m" + msg + "\033[0m")
|
||||
|
@ -692,6 +714,32 @@ def convert_lcm_base():
|
|||
)
|
||||
|
||||
|
||||
def convert_sdxl_lightning_base():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl_lightning_4step_unet.safetensors",
|
||||
additional_args=[
|
||||
"--override-weights",
|
||||
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors",
|
||||
],
|
||||
half=True,
|
||||
expected_hash="cfdc46da",
|
||||
)
|
||||
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl_lightning_1step_unet_x0.safetensors",
|
||||
additional_args=[
|
||||
"--override-weights",
|
||||
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_1step_unet_x0.safetensors",
|
||||
],
|
||||
half=True,
|
||||
expected_hash="21166a64",
|
||||
)
|
||||
|
||||
|
||||
def download_all():
|
||||
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
||||
download_sd15("runwayml/stable-diffusion-v1-5")
|
||||
|
@ -710,6 +758,8 @@ def download_all():
|
|||
download_control_lora_fooocus()
|
||||
download_lcm_base()
|
||||
download_lcm_lora()
|
||||
download_sdxl_lightning_base()
|
||||
download_sdxl_lightning_lora()
|
||||
|
||||
|
||||
def convert_all():
|
||||
|
@ -727,6 +777,7 @@ def convert_all():
|
|||
convert_dinov2()
|
||||
convert_control_lora_fooocus()
|
||||
convert_lcm_base()
|
||||
convert_sdxl_lightning_base()
|
||||
|
||||
|
||||
def main():
|
||||
|
|
Loading…
Reference in a new issue