remove add_multiple_loras

This commit is contained in:
Pierre Chapuis 2024-03-05 16:46:28 +01:00
parent c383ff6cf4
commit 052a20b897
4 changed files with 20 additions and 47 deletions

View file

@ -251,11 +251,8 @@ This is dead simple as [`SDLoraManager`][refiners.foundationals.latent_diffusion
```py ```py
# Load LoRAs weights from disk and inject them into target # Load LoRAs weights from disk and inject them into target
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors") manager.add_loras("scifi-lora", load_from_safetensors("Sci-fi_Environments_sdxl.safetensors"))
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors") manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.safetensors"))
manager.add_multiple_loras(
{"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights}
)
``` ```
Adapters such as LoRAs also have a [scale][refiners.fluxion.adapters.Lora.scale] (roughly) quantifying the effect of this Adapter. Adapters such as LoRAs also have a [scale][refiners.fluxion.adapters.Lora.scale] (roughly) quantifying the effect of this Adapter.
@ -264,12 +261,8 @@ Refiners allows setting different scales for each Adapter, allowing the user to
```py ```py
# Load LoRAs weights from disk and inject them into target # Load LoRAs weights from disk and inject them into target
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors") manager.add_loras("scifi-lora", load_from_safetensors("Sci-fi_Environments_sdxl.safetensors"), scale=1.0)
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors") manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.safetensors"), scale=1.4)
manager.add_multiple_loras(
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
scale={"scifi-lora": 1.0, "pixel-art-lora": 1.4},
)
``` ```
??? example "Expand to see the entire end-to-end code" ??? example "Expand to see the entire end-to-end code"
@ -291,10 +284,8 @@ manager.add_multiple_loras(
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors") scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors") pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras( manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.0)
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights}, manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.4)
scale={"scifi-lora": 1.0, "pixel-art-lora": 1.4},
)
# Hyperparameters # Hyperparameters
prompt = "a futuristic castle surrounded by a forest, mountains in the background" prompt = "a futuristic castle surrounded by a forest, mountains in the background"
@ -416,10 +407,8 @@ with torch.no_grad():
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors") scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors") pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras( manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.5)
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights}, manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.55)
scale={"scifi-lora": 1.5, "pixel-art-lora": 1.55},
)
# Load IP-Adapter # Load IP-Adapter
ip_adapter = SDXLIPAdapter( ip_adapter = SDXLIPAdapter(
@ -543,10 +532,8 @@ with torch.no_grad():
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors") scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors") pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
manager.add_multiple_loras( manager.add_loras("scifi-lora", scifi_lora_weights, scale=1.5)
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights}, manager.add_loras("pixel-art-lora", pixel_art_lora_weights, scale=1.55)
scale={"scifi-lora": 1.5, "pixel-art-lora": 1.55},
)
# Load IP-Adapter # Load IP-Adapter
ip_adapter = SDXLIPAdapter( ip_adapter = SDXLIPAdapter(

View file

@ -95,25 +95,6 @@ class SDLoraManager:
# set the scale of the LoRA # set the scale of the LoRA
self.set_scale(name, scale) self.set_scale(name, scale)
def add_multiple_loras(
self,
/,
tensors: dict[str, dict[str, Tensor]],
scale: dict[str, float] | None = None,
) -> None:
"""Load multiple LoRAs from a dictionary of `state_dict`.
Args:
tensors: The dictionary of `state_dict` of the LoRAs to load
(keys are the names of the LoRAs, values are the `state_dict` of the LoRAs).
scale: The scales to use for the LoRAs.
Raises:
AssertionError: If the manager already has a LoRA with the same name.
"""
for name, lora_tensors in tensors.items():
self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None: def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the text encoder. """Add multiple LoRAs to the text encoder.

View file

@ -335,7 +335,6 @@ def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Ten
}, { }, {
"age": 0.3, "age": 0.3,
"cartoon_style": -0.2, "cartoon_style": -0.2,
"dpo": 1.4,
"eyesize": -0.2, "eyesize": -0.2,
} }
@ -1436,7 +1435,10 @@ def test_diffusion_sdxl_multiple_loras(
loras, scales = lora_sliders loras, scales = lora_sliders
loras["dpo"] = dpo loras["dpo"] = dpo
SDLoraManager(sdxl).add_multiple_loras(loras, scales) manager = SDLoraManager(sdxl)
for lora_name, lora_weights in loras.items():
manager.add_loras(lora_name, lora_weights, scales[lora_name])
manager.add_loras("dpo", dpo, 1.4)
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
# except that we are using DDIM instead of sde-dpmsolver++ # except that we are using DDIM instead of sde-dpmsolver++

View file

@ -36,13 +36,15 @@ def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) ->
def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None: def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights}) manager.add_loras("pokemon-lora", weights)
manager.add_loras("pokemon-lora2", weights)
assert "pokemon-lora" in manager.names assert "pokemon-lora" in manager.names
assert "pokemon-lora2" in manager.names assert "pokemon-lora2" in manager.names
def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None: def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights}) manager.add_loras("pokemon-lora", weights)
manager.add_loras("pokemon-lora2", weights)
manager.remove_loras("pokemon-lora") manager.remove_loras("pokemon-lora")
assert "pokemon-lora" not in manager.names assert "pokemon-lora" not in manager.names
assert "pokemon-lora2" in manager.names assert "pokemon-lora2" in manager.names
@ -53,7 +55,8 @@ def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor])
def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None: def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights}) manager.add_loras("pokemon-lora", weights)
manager.add_loras("pokemon-lora2", weights)
manager.remove_all() manager.remove_all()
assert len(manager.names) == 0 assert len(manager.names) == 0