remove add_multiple_loras

This commit is contained in:
Pierre Chapuis 2024-03-05 16:46:28 +01:00
parent c383ff6cf4
commit 052a20b897
4 changed files with 20 additions and 47 deletions

View file

@ -251,11 +251,8 @@ This is dead simple as [`SDLoraManager`][refiners.foundationals.latent_diffusion
```py
# Load LoRAs weights from disk and inject them into target
manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras(
{"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights}
)
manager.add_loras("scifi-lora", load_from_safetensors("Sci-fi_Environments_sdxl.safetensors"))
manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.safetensors"))
```
Adapters such as LoRAs also have a [scale][refiners.fluxion.adapters.Lora.scale] (roughly) quantifying the effect of this Adapter.
@ -264,12 +261,8 @@ Refiners allows setting different scales for each Adapter, allowing the user to
```py
# Load LoRAs weights from disk and inject them into target
manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras(
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
scale={"scifi-lora": 1.0, "pixel-art-lora": 1.4},
)
manager.add_loras("scifi-lora", load_from_safetensors("Sci-fi_Environments_sdxl.safetensors"), scale=1.0)
manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.safetensors"), scale=1.4)
```
??? example "Expand to see the entire end-to-end code"
@ -291,10 +284,8 @@ manager.add_multiple_loras(
manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras(
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
scale={"scifi-lora": 1.0, "pixel-art-lora": 1.4},
)
manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.0)
manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.4)
# Hyperparameters
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
@ -416,10 +407,8 @@ with torch.no_grad():
manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras(
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
scale={"scifi-lora": 1.5, "pixel-art-lora": 1.55},
)
manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.5)
manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.55)
# Load IP-Adapter
ip_adapter = SDXLIPAdapter(
@ -543,10 +532,8 @@ with torch.no_grad():
manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras(
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
scale={"scifi-lora": 1.5, "pixel-art-lora": 1.55},
)
manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.5)
manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.55)
# Load IP-Adapter
ip_adapter = SDXLIPAdapter(

View file

@ -95,25 +95,6 @@ class SDLoraManager:
# set the scale of the LoRA
self.set_scale(name, scale)
def add_multiple_loras(
self,
/,
tensors: dict[str, dict[str, Tensor]],
scale: dict[str, float] | None = None,
) -> None:
"""Load multiple LoRAs from a dictionary of `state_dict`.
Args:
tensors: The dictionary of `state_dict` of the LoRAs to load
(keys are the names of the LoRAs, values are the `state_dict` of the LoRAs).
scale: The scales to use for the LoRAs.
Raises:
AssertionError: If the manager already has a LoRA with the same name.
"""
for name, lora_tensors in tensors.items():
self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the text encoder.

View file

@ -335,7 +335,6 @@ def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Ten
}, {
"age": 0.3,
"cartoon_style": -0.2,
"dpo": 1.4,
"eyesize": -0.2,
}
@ -1436,7 +1435,10 @@ def test_diffusion_sdxl_multiple_loras(
loras, scales = lora_sliders
loras["dpo"] = dpo
SDLoraManager(sdxl).add_multiple_loras(loras, scales)
manager = SDLoraManager(sdxl)
for lora_name, lora_weights in loras.items():
manager.add_loras(lora_name, lora_weights, scales[lora_name])
manager.add_loras("dpo", dpo, 1.4)
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
# except that we are using DDIM instead of sde-dpmsolver++

View file

@ -36,13 +36,15 @@ def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) ->
def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
manager.add_loras("pokemon-lora", weights)
manager.add_loras("pokemon-lora2", weights)
assert "pokemon-lora" in manager.names
assert "pokemon-lora2" in manager.names
def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
manager.add_loras("pokemon-lora", weights)
manager.add_loras("pokemon-lora2", weights)
manager.remove_loras("pokemon-lora")
assert "pokemon-lora" not in manager.names
assert "pokemon-lora2" in manager.names
@ -53,7 +55,8 @@ def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor])
def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
manager.add_loras("pokemon-lora", weights)
manager.add_loras("pokemon-lora2", weights)
manager.remove_all()
assert len(manager.names) == 0