diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index ce42c77..d57acac 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -52,6 +52,7 @@ def download_file(
     dry_run: bool | None = None,
     skip_existing: bool = True,
     expected_hash: str | None = None,
+    filename: str | None = None,
 ):
     """
     Downloads a file
@@ -65,7 +66,7 @@ def download_file(
     """
     global download_count, bytes_count
 
-    filename = os.path.basename(urlparse(url).path)
+    filename = os.path.basename(urlparse(url).path) if filename is None else filename
     dest_filename = os.path.join(dest_folder, filename)
     temp_filename = dest_filename + ".part"
     dry_run = bool(os.environ.get("DRY_RUN") == "1") if dry_run is None else dry_run
@@ -364,6 +365,24 @@ def download_dinov2():
     download_files(urls, base_folder)
 
 
+def download_lcm_base():
+    base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
+    download_file("https://huggingface.co/latent-consistency/lcm-sdxl/raw/main/config.json", base_folder)
+    download_file(
+        "https://huggingface.co/latent-consistency/lcm-sdxl/resolve/main/diffusion_pytorch_model.safetensors",
+        base_folder,
+    )
+
+
+def download_lcm_lora():
+    download_file(
+        "https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors",
+        dest_folder=test_weights_dir,
+        filename="sdxl-lcm-lora.safetensors",
+        expected_hash="6312a30a",
+    )
+
+
 def printg(msg: str):
     """print in green color"""
     print("\033[92m" + msg + "\033[0m")
@@ -653,6 +672,16 @@ def convert_control_lora_fooocus():
     )
 
 
+def convert_lcm_base():
+    run_conversion_script(
+        "convert_diffusers_unet.py",
+        "tests/weights/latent-consistency/lcm-sdxl",
+        "tests/weights/sdxl-lcm-unet.safetensors",
+        half=True,
+        expected_hash="242cf440",
+    )
+
+
 def download_all():
     print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
     download_sd15("runwayml/stable-diffusion-v1-5")
@@ -669,6 +698,8 @@ def download_all():
     download_sam()
     download_dinov2()
     download_control_lora_fooocus()
+    download_lcm_base()
+    download_lcm_lora()
 
 
 def convert_all():
@@ -685,6 +716,7 @@ def convert_all():
     convert_sam()
     convert_dinov2()
     convert_control_lora_fooocus()
+    convert_lcm_base()
 
 
 def main():