From 052a20b897b405c3b0df2149ff9ed9480020dc15 Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Tue, 5 Mar 2024 16:46:28 +0100
Subject: [PATCH] remove add_multiple_loras

---
 docs/guides/adapting_sdxl/index.md             | 33 ++++++-------
 .../foundationals/latent_diffusion/lora.py     | 19 -----------
 tests/e2e/test_diffusion.py                    |  7 ++++---
 .../latent_diffusion/test_lora_manager.py      |  9 +++--
 4 files changed, 20 insertions(+), 48 deletions(-)

diff --git a/docs/guides/adapting_sdxl/index.md b/docs/guides/adapting_sdxl/index.md
index 5e50080..5ac1a96 100644
--- a/docs/guides/adapting_sdxl/index.md
+++ b/docs/guides/adapting_sdxl/index.md
@@ -251,11 +251,8 @@ This is dead simple as [`SDLoraManager`][refiners.foundationals.latent_diffusion
 ```py
 # Load LoRAs weights from disk and inject them into target
 manager = SDLoraManager(sdxl)
-scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
-pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
-manager.add_multiple_loras(
-    {"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights}
-)
+manager.add_loras("scifi-lora", load_from_safetensors("Sci-fi_Environments_sdxl.safetensors"))
+manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.safetensors"))
 ```
 
 Adapters such as LoRAs also have a [scale][refiners.fluxion.adapters.Lora.scale] (roughly) quantifying the effect of this Adapter.
@@ -264,12 +261,8 @@ Refiners allows setting different scales for each Adapter, allowing the user to
 ```py
 # Load LoRAs weights from disk and inject them into target
 manager = SDLoraManager(sdxl)
-scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
-pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
-manager.add_multiple_loras(
-    tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
-    scale={"scifi-lora": 1.0, "pixel-art-lora": 1.4},
-)
+manager.add_loras("scifi-lora", load_from_safetensors("Sci-fi_Environments_sdxl.safetensors"), scale=1.0)
+manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.safetensors"), scale=1.4)
 ```
 
 ??? example "Expand to see the entire end-to-end code"
@@ -291,10 +284,8 @@ manager.add_multiple_loras(
     manager = SDLoraManager(sdxl)
     scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
     pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
-    manager.add_multiple_loras(
-        tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
-        scale={"scifi-lora": 1.0, "pixel-art-lora": 1.4},
-    )
+    manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.0)
+    manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.4)
 
     # Hyperparameters
     prompt = "a futuristic castle surrounded by a forest, mountains in the background"
@@ -416,10 +407,8 @@ with torch.no_grad():
     manager = SDLoraManager(sdxl)
     scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
     pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
-    manager.add_multiple_loras(
-        tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
-        scale={"scifi-lora": 1.5, "pixel-art-lora": 1.55},
-    )
+    manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.5)
+    manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.55)
 
     # Load IP-Adapter
     ip_adapter = SDXLIPAdapter(
@@ -543,10 +532,8 @@ with torch.no_grad():
     manager = SDLoraManager(sdxl)
     scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
     pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
-    manager.add_multiple_loras(
-        tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
-        scale={"scifi-lora": 1.5, "pixel-art-lora": 1.55},
-    )
+    manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.5)
+    manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.55)
 
     # Load IP-Adapter
     ip_adapter = SDXLIPAdapter(
diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py
index 862b1dc..94e231f 100644
--- a/src/refiners/foundationals/latent_diffusion/lora.py
+++ b/src/refiners/foundationals/latent_diffusion/lora.py
@@ -95,25 +95,6 @@ class SDLoraManager:
         # set the scale of the LoRA
         self.set_scale(name, scale)
 
-    def add_multiple_loras(
-        self,
-        /,
-        tensors: dict[str, dict[str, Tensor]],
-        scale: dict[str, float] | None = None,
-    ) -> None:
-        """Load multiple LoRAs from a dictionary of `state_dict`.
-
-        Args:
-            tensors: The dictionary of `state_dict` of the LoRAs to load
-                (keys are the names of the LoRAs, values are the `state_dict` of the LoRAs).
-            scale: The scales to use for the LoRAs.
-
-        Raises:
-            AssertionError: If the manager already has a LoRA with the same name.
-        """
-        for name, lora_tensors in tensors.items():
-            self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
-
     def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the text encoder.
 
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 23582e8..69465e0 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -335,7 +335,6 @@ def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Ten
     }, {
         "age": 0.3,
         "cartoon_style": -0.2,
-        "dpo": 1.4,
         "eyesize": -0.2,
     }
 
@@ -1436,7 +1435,9 @@ def test_diffusion_sdxl_multiple_loras(
 
     loras, scales = lora_sliders
-    loras["dpo"] = dpo
-    SDLoraManager(sdxl).add_multiple_loras(loras, scales)
+    manager = SDLoraManager(sdxl)
+    for lora_name, lora_weights in loras.items():
+        manager.add_loras(lora_name, lora_weights, scales[lora_name])
+    manager.add_loras("dpo", dpo, 1.4)
 
     # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
     # except that we are using DDIM instead of sde-dpmsolver++
diff --git a/tests/foundationals/latent_diffusion/test_lora_manager.py b/tests/foundationals/latent_diffusion/test_lora_manager.py
index a126934..4e30ce8 100644
--- a/tests/foundationals/latent_diffusion/test_lora_manager.py
+++ b/tests/foundationals/latent_diffusion/test_lora_manager.py
@@ -36,13 +36,15 @@ def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) ->
 
 
 def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
+    manager.add_loras("pokemon-lora", weights)
+    manager.add_loras("pokemon-lora2", weights)
     assert "pokemon-lora" in manager.names
     assert "pokemon-lora2" in manager.names
 
 
 def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
+    manager.add_loras("pokemon-lora", weights)
+    manager.add_loras("pokemon-lora2", weights)
     manager.remove_loras("pokemon-lora")
     assert "pokemon-lora" not in manager.names
     assert "pokemon-lora2" in manager.names
@@ -53,7 +55,8 @@ def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor])
 
 
 def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
+    manager.add_loras("pokemon-lora", weights)
+    manager.add_loras("pokemon-lora2", weights)
     manager.remove_all()
    assert len(manager.names) == 0
 
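Note for downstream users of the removed API: the deleted method body above shows the migration is just a loop over `add_loras`. A minimal compatibility sketch mirroring the removed implementation (the helper name `add_loras_one_by_one` is illustrative, not part of Refiners):

```py
import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.lora import SDLoraManager


def add_loras_one_by_one(
    target: StableDiffusion_XL,
    tensors: dict[str, dict[str, torch.Tensor]],
    scale: dict[str, float] | None = None,
) -> SDLoraManager:
    """Replicate the removed add_multiple_loras with one add_loras call per LoRA."""
    manager = SDLoraManager(target)
    for name, lora_tensors in tensors.items():
        # Same default as the deleted method: scale 1.0 when no scale dict is given.
        manager.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
    return manager
```

Calling it with the two dicts previously passed to `add_multiple_loras` is equivalent to the explicit per-LoRA `add_loras` calls the docs now recommend.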