mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add SDXL-Lightning weights to conversion script + support safetensors
This commit is contained in:
parent
7f51d18045
commit
7d8e3fc1db
|
@ -45,7 +45,7 @@ First, install test dependencies with:
|
||||||
rye sync --all-features
|
rye sync --all-features
|
||||||
```
|
```
|
||||||
|
|
||||||
Then, download and convert all the necessary weights. Be aware that this will use around 80 GB of disk space:
|
Then, download and convert all the necessary weights. Be aware that this will use around 100 GB of disk space:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/prepare_test_weights.py
|
python scripts/prepare_test_weights.py
|
||||||
|
|
|
@ -7,6 +7,7 @@ from diffusers import UNet2DConditionModel # type: ignore
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, load_tensors
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||||
|
|
||||||
|
@ -28,7 +29,12 @@ def setup_converter(args: Args) -> ModelConverter:
|
||||||
low_cpu_mem_usage=False,
|
low_cpu_mem_usage=False,
|
||||||
)
|
)
|
||||||
if args.override_weights is not None:
|
if args.override_weights is not None:
|
||||||
sd = torch.load(args.override_weights) # type: ignore
|
if args.override_weights.endswith(".pth"):
|
||||||
|
sd = load_tensors(args.override_weights)
|
||||||
|
elif args.override_weights.endswith(".safetensors"):
|
||||||
|
sd = load_from_safetensors(args.override_weights)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported file format: {args.override_weights}")
|
||||||
source.load_state_dict(sd)
|
source.load_state_dict(sd)
|
||||||
source_in_channels: int = source.config.in_channels # type: ignore
|
source_in_channels: int = source.config.in_channels # type: ignore
|
||||||
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||||
|
|
|
@ -393,6 +393,28 @@ def download_lcm_lora():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_sdxl_lightning_base():
|
||||||
|
base_folder = os.path.join(test_weights_dir, "ByteDance/SDXL-Lightning")
|
||||||
|
download_file(
|
||||||
|
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_unet.safetensors",
|
||||||
|
base_folder,
|
||||||
|
expected_hash="1b76cca3",
|
||||||
|
)
|
||||||
|
download_file(
|
||||||
|
f"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_1step_unet_x0.safetensors",
|
||||||
|
base_folder,
|
||||||
|
expected_hash="38e605bd",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_sdxl_lightning_lora():
|
||||||
|
download_file(
|
||||||
|
"https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors",
|
||||||
|
dest_folder=test_weights_dir,
|
||||||
|
expected_hash="9783edac",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def printg(msg: str):
|
def printg(msg: str):
|
||||||
"""print in green color"""
|
"""print in green color"""
|
||||||
print("\033[92m" + msg + "\033[0m")
|
print("\033[92m" + msg + "\033[0m")
|
||||||
|
@ -692,6 +714,32 @@ def convert_lcm_base():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_sdxl_lightning_base():
|
||||||
|
run_conversion_script(
|
||||||
|
"convert_diffusers_unet.py",
|
||||||
|
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||||
|
"tests/weights/sdxl_lightning_4step_unet.safetensors",
|
||||||
|
additional_args=[
|
||||||
|
"--override-weights",
|
||||||
|
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors",
|
||||||
|
],
|
||||||
|
half=True,
|
||||||
|
expected_hash="cfdc46da",
|
||||||
|
)
|
||||||
|
|
||||||
|
run_conversion_script(
|
||||||
|
"convert_diffusers_unet.py",
|
||||||
|
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||||
|
"tests/weights/sdxl_lightning_1step_unet_x0.safetensors",
|
||||||
|
additional_args=[
|
||||||
|
"--override-weights",
|
||||||
|
"tests/weights/ByteDance/SDXL-Lightning/sdxl_lightning_1step_unet_x0.safetensors",
|
||||||
|
],
|
||||||
|
half=True,
|
||||||
|
expected_hash="21166a64",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def download_all():
|
def download_all():
|
||||||
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
||||||
download_sd15("runwayml/stable-diffusion-v1-5")
|
download_sd15("runwayml/stable-diffusion-v1-5")
|
||||||
|
@ -710,6 +758,8 @@ def download_all():
|
||||||
download_control_lora_fooocus()
|
download_control_lora_fooocus()
|
||||||
download_lcm_base()
|
download_lcm_base()
|
||||||
download_lcm_lora()
|
download_lcm_lora()
|
||||||
|
download_sdxl_lightning_base()
|
||||||
|
download_sdxl_lightning_lora()
|
||||||
|
|
||||||
|
|
||||||
def convert_all():
|
def convert_all():
|
||||||
|
@ -727,6 +777,7 @@ def convert_all():
|
||||||
convert_dinov2()
|
convert_dinov2()
|
||||||
convert_control_lora_fooocus()
|
convert_control_lora_fooocus()
|
||||||
convert_lcm_base()
|
convert_lcm_base()
|
||||||
|
convert_sdxl_lightning_base()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
Loading…
Reference in a new issue