mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
add LCM and LCM-LoRA to tests weights conversion script
This commit is contained in:
parent
12b6829a26
commit
b55e9332fe
|
@ -52,6 +52,7 @@ def download_file(
|
|||
dry_run: bool | None = None,
|
||||
skip_existing: bool = True,
|
||||
expected_hash: str | None = None,
|
||||
filename: str | None = None,
|
||||
):
|
||||
"""
|
||||
Downloads a file
|
||||
|
@ -65,7 +66,7 @@ def download_file(
|
|||
|
||||
"""
|
||||
global download_count, bytes_count
|
||||
filename = os.path.basename(urlparse(url).path)
|
||||
filename = os.path.basename(urlparse(url).path) if filename is None else filename
|
||||
dest_filename = os.path.join(dest_folder, filename)
|
||||
temp_filename = dest_filename + ".part"
|
||||
dry_run = bool(os.environ.get("DRY_RUN") == "1") if dry_run is None else dry_run
|
||||
|
@ -364,6 +365,24 @@ def download_dinov2():
|
|||
download_files(urls, base_folder)
|
||||
|
||||
|
||||
def download_lcm_base():
|
||||
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
|
||||
download_file(f"https://huggingface.co/latent-consistency/lcm-sdxl/raw/main/config.json", base_folder)
|
||||
download_file(
|
||||
f"https://huggingface.co/latent-consistency/lcm-sdxl/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
base_folder,
|
||||
)
|
||||
|
||||
|
||||
def download_lcm_lora():
|
||||
download_file(
|
||||
"https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors",
|
||||
dest_folder=test_weights_dir,
|
||||
filename="sdxl-lcm-lora.safetensors",
|
||||
expected_hash="6312a30a",
|
||||
)
|
||||
|
||||
|
||||
def printg(msg: str):
|
||||
"""print in green color"""
|
||||
print("\033[92m" + msg + "\033[0m")
|
||||
|
@ -653,6 +672,16 @@ def convert_control_lora_fooocus():
|
|||
)
|
||||
|
||||
|
||||
def convert_lcm_base():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/latent-consistency/lcm-sdxl",
|
||||
"tests/weights/sdxl-lcm-unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="242cf440",
|
||||
)
|
||||
|
||||
|
||||
def download_all():
|
||||
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
||||
download_sd15("runwayml/stable-diffusion-v1-5")
|
||||
|
@ -669,6 +698,8 @@ def download_all():
|
|||
download_sam()
|
||||
download_dinov2()
|
||||
download_control_lora_fooocus()
|
||||
download_lcm_base()
|
||||
download_lcm_lora()
|
||||
|
||||
|
||||
def convert_all():
|
||||
|
@ -685,6 +716,7 @@ def convert_all():
|
|||
convert_sam()
|
||||
convert_dinov2()
|
||||
convert_control_lora_fooocus()
|
||||
convert_lcm_base()
|
||||
|
||||
|
||||
def main():
|
||||
|
|
Loading…
Reference in a new issue