add LCM and LCM-LoRA to tests weights conversion script

This commit is contained in:
Pierre Chapuis 2024-02-16 13:17:39 +01:00
parent 12b6829a26
commit b55e9332fe

View file

@ -52,6 +52,7 @@ def download_file(
dry_run: bool | None = None,
skip_existing: bool = True,
expected_hash: str | None = None,
filename: str | None = None,
):
"""
Downloads a file
@ -65,7 +66,7 @@ def download_file(
"""
global download_count, bytes_count
filename = os.path.basename(urlparse(url).path)
filename = os.path.basename(urlparse(url).path) if filename is None else filename
dest_filename = os.path.join(dest_folder, filename)
temp_filename = dest_filename + ".part"
dry_run = bool(os.environ.get("DRY_RUN") == "1") if dry_run is None else dry_run
@ -364,6 +365,24 @@ def download_dinov2():
download_files(urls, base_folder)
def download_lcm_base():
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
download_file(f"https://huggingface.co/latent-consistency/lcm-sdxl/raw/main/config.json", base_folder)
download_file(
f"https://huggingface.co/latent-consistency/lcm-sdxl/resolve/main/diffusion_pytorch_model.safetensors",
base_folder,
)
def download_lcm_lora():
download_file(
"https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors",
dest_folder=test_weights_dir,
filename="sdxl-lcm-lora.safetensors",
expected_hash="6312a30a",
)
def printg(msg: str):
"""print in green color"""
print("\033[92m" + msg + "\033[0m")
@ -653,6 +672,16 @@ def convert_control_lora_fooocus():
)
def convert_lcm_base():
run_conversion_script(
"convert_diffusers_unet.py",
"tests/weights/latent-consistency/lcm-sdxl",
"tests/weights/sdxl-lcm-unet.safetensors",
half=True,
expected_hash="242cf440",
)
def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5")
@ -669,6 +698,8 @@ def download_all():
download_sam()
download_dinov2()
download_control_lora_fooocus()
download_lcm_base()
download_lcm_lora()
def convert_all():
@ -685,6 +716,7 @@ def convert_all():
convert_sam()
convert_dinov2()
convert_control_lora_fooocus()
convert_lcm_base()
def main():