Load Multiple LoRAs with SDLoraManager

limiteinductive 2024-01-22 14:45:34 +01:00 committed by Benjamin Trom
parent fb2f0e28d4
commit 421da6a3b6
8 changed files with 439 additions and 74 deletions

View file

@@ -238,6 +238,11 @@ def download_loras():
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
)
dest_folder = os.path.join(test_weights_dir, "loras", "sliders")
download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder)
download_file("https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder)
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder)
def download_preprocessors():
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")

View file

@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Any
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
@@ -7,18 +6,20 @@ from torch.nn.init import normal_, zeros_
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.layers.chain import Chain
class Lora(fl.Chain, ABC):
def __init__(
self,
name: str,
/,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.rank = rank
self.name = name
self._rank = rank
self._scale = scale
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
@@ -44,6 +45,10 @@ class Lora(fl.Chain, ABC):
assert isinstance(up_layer, fl.WeightedModule)
return up_layer
@property
def rank(self) -> int:
return self._rank
@property
def scale(self) -> float:
return self._scale
@@ -56,19 +61,21 @@ class Lora(fl.Chain, ABC):
@classmethod
def from_weights(
cls,
name: str,
/,
down: Tensor,
up: Tensor,
) -> "Lora":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(up=up, down=down)
return LinearLora.from_weights(name, up=up, down=down)
case (4, 4):
return Conv2dLora.from_weights(up=up, down=down)
return Conv2dLora.from_weights(name, up=up, down=down)
case _:
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
@classmethod
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
"""
Create a dictionary of LoRA layers from a state dict.
@@ -80,13 +87,37 @@ class Lora(fl.Chain, ABC):
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
key = ".".join(down_key.split(".")[:-2])
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor)
return loras
@abstractmethod
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
...
def auto_attach(
self, target: fl.Chain, exclude: list[str] | None = None
) -> "tuple[LoraAdapter, fl.Chain | None] | None":
for layer, parent in target.walk(self.up.__class__):
if isinstance(parent, Lora):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if not self.is_compatible(layer):
continue
if isinstance(parent, LoraAdapter):
if self.name in parent.names:
continue
else:
parent.add_lora(self)
return parent, None
return LoraAdapter(layer, self), parent
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
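For orientation, a minimal usage sketch of the new name-aware API (the key names and shapes below are illustrative, not taken from this commit): Lora.from_dict groups consecutive ".down.weight" / ".up.weight" pairs under a common key, and auto_attach either wraps a compatible layer in a fresh LoraAdapter (returning the adapter plus the parent to inject it into) or, when the layer is already adapted, appends the LoRA to the existing adapter and returns (adapter, None).

import torch
from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.lora import Lora

# Illustrative CivitAI-style state dict: down/up weights come as consecutive pairs.
state_dict = {
    "unet.mid_block.attn.to_q.down.weight": torch.randn(16, 320),  # (rank, in_features)
    "unet.mid_block.attn.to_q.up.weight": torch.randn(320, 16),    # (out_features, rank)
}
loras = Lora.from_dict("my_style", state_dict=state_dict)
assert list(loras.keys()) == ["unet.mid_block.attn.to_q"]

target = fl.Chain(fl.Linear(320, 320))  # stand-in for a real UNet or text encoder
for lora in loras.values():
    attach = lora.auto_attach(target)
    if attach is None:
        continue  # no compatible layer found
    adapter, parent = attach
    if parent is not None:
        adapter.inject(parent)  # a new LoraAdapter was created around the layer
    # else: the layer already had a LoraAdapter and this LoRA was appended to it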
@@ -97,6 +128,8 @@ class Lora(fl.Chain, ABC):
class LinearLora(Lora):
def __init__(
self,
name: str,
/,
in_features: int,
out_features: int,
rank: int = 16,
@@ -107,35 +140,29 @@ class LinearLora(Lora):
self.in_features = in_features
self.out_features = out_features
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
name: str,
/,
down: Tensor,
up: Tensor,
) -> "LinearLora":
assert up.ndim == 2 and down.ndim == 2
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
lora = cls(
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
name,
in_features=down.shape[1],
out_features=up.shape[0],
rank=down.shape[0],
device=up.device,
dtype=up.dtype,
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Linear):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if layer.in_features == self.in_features and layer.out_features == self.out_features:
return LoraAdapter(target=layer, lora=self), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Linear, fl.Linear]:
@@ -156,10 +183,17 @@ class LinearLora(Lora):
),
)
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
if isinstance(layer, fl.Linear):
return layer.in_features == self.in_features and layer.out_features == self.out_features
return False
class Conv2dLora(Lora):
def __init__(
self,
name: str,
/,
in_channels: int,
out_channels: int,
rank: int = 16,
@@ -176,20 +210,24 @@ class Conv2dLora(Lora):
self.stride = stride
self.padding = padding
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
name: str,
/,
down: Tensor,
up: Tensor,
) -> "Conv2dLora":
assert up.ndim == 4 and down.ndim == 4
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
# padding is set so the spatial dimensions are preserved (assuming stride=1 and kernel_size either 1 or 3)
down_padding = 1 if down_kernel_size == 3 else 0
up_padding = 1 if up_kernel_size == 3 else 0
lora = cls(
name,
in_channels=down.shape[1],
out_channels=up.shape[0],
rank=down.shape[0],
@@ -201,25 +239,6 @@ class Conv2dLora(Lora):
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Conv2d):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
if layer.stride != (self.stride[0], self.stride[0]):
self.down.stride = layer.stride
return LoraAdapter(
target=layer,
lora=self,
), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Conv2d, fl.Conv2d]:
@@ -246,20 +265,47 @@ class Conv2dLora(Lora):
),
)
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
if (
isinstance(layer, fl.Conv2d)
and layer.in_channels == self.in_channels
and layer.out_channels == self.out_channels
):
# stride cannot be inferred from the weights, so we assume it's the same as the layer
self.down.stride = layer.stride
return True
return False
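As a quick illustration of the shape conventions in Conv2dLora.from_weights above (the tensors here are made up): kernel sizes are read from the down/up weight shapes, and padding is picked so the two convolutions together preserve spatial dimensions.

import torch
from refiners.fluxion.adapters.lora import Conv2dLora

down = torch.randn(4, 16, 3, 3)  # (rank, in_channels, 3, 3) -> down conv gets padding 1
up = torch.randn(8, 4, 1, 1)     # (out_channels, rank, 1, 1) -> up conv gets padding 0
conv_lora = Conv2dLora.from_weights("conv_example", down=down, up=up)
assert conv_lora.kernel_size == (3, 1)
assert conv_lora.padding == (1, 0)
x = torch.randn(1, 16, 32, 32)
assert conv_lora(x).shape == (1, 8, 32, 32)  # spatial dimensions preserved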
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
with self.setup_adapter(target):
super().__init__(target, lora)
super().__init__(target, *loras)
@property
def lora(self) -> Lora:
return self.ensure_find(Lora)
def names(self) -> list[str]:
return [lora.name for lora in self.layers(Lora)]
@property
def scale(self) -> float:
return self.lora.scale
def loras(self) -> dict[str, Lora]:
return {lora.name: lora for lora in self.layers(Lora)}
@scale.setter
def scale(self, value: float) -> None:
self.lora.scale = value
@property
def scales(self) -> dict[str, float]:
return {lora.name: lora.scale for lora in self.layers(Lora)}
@scales.setter
def scale(self, values: dict[str, float]) -> None:
for name, value in values.items():
self.loras[name].scale = value
def add_lora(self, lora: Lora, /) -> None:
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
self.append(lora)
def remove_lora(self, name: str, /) -> Lora | None:
if name in self.names:
lora = self.loras[name]
self.remove(lora)
return lora
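A short sketch of the multi-LoRA adapter API introduced here (names and dimensions are made up for illustration, mirroring the new unit tests below):

from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.lora import LinearLora, LoraAdapter

adapter = LoraAdapter(
    fl.Linear(320, 128),
    LinearLora("style", in_features=320, out_features=128, rank=16, scale=0.8),
    LinearLora("detail", in_features=320, out_features=128, rank=4),
)
assert set(adapter.names) == {"style", "detail"}
assert adapter.scales == {"style": 0.8, "detail": 1.0}
adapter.scale = {"style": 0.5, "detail": 1.2}  # batch update through the scales setter
assert adapter.scales == {"style": 0.5, "detail": 1.2}
removed = adapter.remove_lora("detail")        # returns the removed Lora, or None if absent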

View file

@@ -26,19 +26,20 @@ class SDLoraManager:
assert isinstance(clip_text_encoder, fl.Chain)
return clip_text_encoder
def load(
def add_loras(
self,
tensors: dict[str, Tensor],
name: str,
/,
tensors: dict[str, Tensor],
scale: float = 1.0,
) -> None:
"""Load the LoRA weights from a dictionary of tensors.
Expects the keys to be in the commonly found formats on CivitAI's hub.
"""
assert len(self.lora_adapters) == 0, "Loras already loaded"
assert name not in self.names, f"LoRA {name} already exists"
loras = Lora.from_dict(
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
name, {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
)
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
@@ -46,16 +47,25 @@ class SDLoraManager:
if not "unet" in loras and not "text" in loras:
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
self.load_unet(loras)
self.load_text_encoder(loras)
self.add_loras_to_unet(loras)
self.add_loras_to_text_encoder(loras)
self.scale = scale
self.set_scale(name, scale)
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
def add_multiple_loras(
self,
/,
tensors: dict[str, dict[str, Tensor]],
scale: dict[str, float] | None = None,
) -> None:
for name, lora_tensors in tensors.items():
self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
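To make the intended workflow concrete, a minimal sketch of adding named LoRAs through the manager. The weight file names are hypothetical, and load_tensors merely stands in for any loader that yields a plain dict of tensors (the conftest below uses it for the .pt slider weights).

from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

manager = SDLoraManager(StableDiffusion_1())

# One named LoRA with an explicit scale (hypothetical local weight file).
manager.add_loras("pokemon", tensors=load_tensors("pokemon_lora.pt"), scale=1.0)

# Several LoRAs at once, with per-name scales (defaults to 1.0 when scale is omitted).
manager.add_multiple_loras(
    {"age": load_tensors("age.pt"), "eyesize": load_tensors("eyesize.pt")},
    scale={"age": 0.3, "eyesize": -0.2},
)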
def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
def load_unet(self, loras: dict[str, Lora], /) -> None:
def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
exclude: list[str] = []
exclude = [
@@ -65,14 +75,43 @@
]
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
def unload(self) -> None:
def remove_loras(self, *names: str) -> None:
for lora_adapter in self.lora_adapters:
for name in names:
lora_adapter.remove_lora(name)
if len(lora_adapter.loras) == 0:
lora_adapter.eject()
def remove_all(self) -> None:
for lora_adapter in self.lora_adapters:
lora_adapter.eject()
def get_loras_by_name(self, name: str, /) -> list[Lora]:
return [lora for lora in self.loras if lora.name == name]
def get_scale(self, name: str, /) -> float:
loras = self.get_loras_by_name(name)
assert all([lora.scale == loras[0].scale for lora in loras]), "lora scales are not all the same"
return loras[0].scale
def set_scale(self, name: str, scale: float, /) -> None:
self.update_scales({name: scale})
def update_scales(self, scales: dict[str, float], /) -> None:
assert all([name in self.names for name in scales]), f"Scales keys must be a subset of {self.names}"
for name, scale in scales.items():
for lora in self.get_loras_by_name(name):
lora.scale = scale
@property
def loras(self) -> list[Lora]:
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
@property
def names(self) -> list[str]:
return list(set(lora.name for lora in self.loras))
@property
def lora_adapters(self) -> list[LoraAdapter]:
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
@@ -87,15 +126,8 @@ class SDLoraManager:
}
@property
def scale(self) -> float:
assert len(self.loras) > 0, "No loras found"
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
return self.loras[0].scale
@scale.setter
def scale(self, value: float) -> None:
for lora in self.loras:
lora.scale = value
def scales(self) -> dict[str, float]:
return {name: self.get_scale(name) for name in self.names}
@staticmethod
def pad(input: str, /, padding_length: int = 2) -> str:
@@ -130,7 +162,9 @@ class SDLoraManager:
for key, lora in loras.items():
if attach := lora.auto_attach(target, exclude=exclude):
adapter, parent = attach
adapter.inject(parent)
# if parent is None, `adapter` is already attached and `lora` has been added to it
if parent is not None:
adapter.inject(parent)
else:
failed_loras[key] = lora
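Rounding out the sketch, per-name scale management and removal with the new manager API (same hypothetical weight files as in the previous snippet):

from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

manager = SDLoraManager(StableDiffusion_1())
manager.add_multiple_loras(
    {"age": load_tensors("age.pt"), "eyesize": load_tensors("eyesize.pt")},  # hypothetical files
    scale={"age": 0.3, "eyesize": -0.2},
)

assert manager.get_scale("eyesize") == -0.2
manager.set_scale("age", 0.5)                        # one LoRA at a time
manager.update_scales({"age": 0.4, "eyesize": 0.1})  # several at once
assert manager.scales == {"age": 0.4, "eyesize": 0.1}

manager.remove_loras("eyesize")  # adapters left without any LoRA are ejected
manager.remove_all()             # detaches every remaining LoraAdapter from the target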

tests/adapters/test_lora.py (new file, 118 lines added)
View file

@@ -0,0 +1,118 @@
import pytest
import torch
from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
@pytest.fixture
def lora() -> LinearLora:
return LinearLora("test", in_features=320, out_features=128, rank=16)
@pytest.fixture
def conv_lora() -> Lora:
return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
assert lora.name == "test"
assert lora.rank == lora.down.out_features == lora.up.in_features == 16
assert lora.scale == 1.0
assert lora.in_features == lora.down.in_features == 320
assert lora.out_features == lora.up.out_features == 128
assert conv_lora.name == "conv_test"
assert conv_lora.rank == conv_lora.down.out_channels == conv_lora.up.in_channels == 4
assert conv_lora.scale == 1.0
assert conv_lora.in_channels == conv_lora.down.in_channels == 16
assert conv_lora.out_channels == conv_lora.up.out_channels == 8
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
# padding is set so the spatial dimensions are preserved
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
def test_scale_setter(lora: LinearLora) -> None:
lora.scale = 2.0
assert lora.scale == 2.0
assert lora.ensure_find(fl.Multiply).scale == 2.0
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
x = torch.randn(1, 320)
assert torch.allclose(lora(x), new_lora(x))
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
x = torch.randn(1, 16, 64, 64)
assert torch.allclose(conv_lora(x), new_conv_lora(x))
def test_from_dict() -> None:
state_dict = {
"down.weight": torch.randn(320, 128),
"up.weight": torch.randn(128, 320),
"this.is_not_used.alpha": torch.randn(1, 320),
"probably.a.conv.down.weight": torch.randn(4, 16, 3, 3),
"probably.a.conv.up.weight": torch.randn(8, 4, 1, 1),
}
loras = Lora.from_dict("test", state_dict=state_dict)
assert len(loras) == 2
linear_lora, conv_lora = loras.values()
assert isinstance(linear_lora, LinearLora)
assert isinstance(conv_lora, Conv2dLora)
assert linear_lora.name == "test"
assert conv_lora.name == "test"
@pytest.fixture
def lora_adapter() -> LoraAdapter:
target = fl.Linear(320, 128)
lora1 = LinearLora("test1", in_features=320, out_features=128, rank=16, scale=2.0)
lora2 = LinearLora("test2", in_features=320, out_features=128, rank=16, scale=-1.0)
return LoraAdapter(target, lora1, lora2)
def test_names(lora_adapter: LoraAdapter) -> None:
assert set(lora_adapter.names) == {"test1", "test2"}
def test_loras(lora_adapter: LoraAdapter) -> None:
assert set(lora_adapter.loras.keys()) == {"test1", "test2"}
def test_scales(lora_adapter: LoraAdapter) -> None:
assert set(lora_adapter.scales.keys()) == {"test1", "test2"}
assert lora_adapter.scales["test1"] == 2.0
assert lora_adapter.scales["test2"] == -1.0
def test_scale_setter_lora_adapter(lora_adapter: LoraAdapter) -> None:
lora_adapter.scale = {"test1": 0.0, "test2": 3.0}
assert lora_adapter.scales == {"test1": 0.0, "test2": 3.0}
def test_add_lora(lora_adapter: LoraAdapter) -> None:
lora3 = LinearLora("test3", in_features=320, out_features=128, rank=16)
lora_adapter.add_lora(lora3)
assert "test3" in lora_adapter.names
def test_remove_lora(lora_adapter: LoraAdapter) -> None:
lora_adapter.remove_lora("test1")
assert "test1" not in lora_adapter.names
def test_add_existing_lora(lora_adapter: LoraAdapter) -> None:
lora3 = LinearLora("test1", in_features=320, out_features=128, rank=16)
with pytest.raises(AssertionError):
lora_adapter.add_lora(lora3)
def test_remove_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
assert lora_adapter.remove_lora("test3") is None
def test_set_scale_for_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
with pytest.raises(KeyError):
lora_adapter.scale = {"test3": 2.0}

View file

@@ -216,6 +216,26 @@ def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image,
return expected_image, tensors
@pytest.fixture(scope="module")
def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
weights_path = test_weights_path / "loras" / "sliders"
if not weights_path.is_dir():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
return {
"age": load_tensors(weights_path / "age.pt"), # type: ignore
"cartoon_style": load_tensors(weights_path / "cartoon_style.pt"), # type: ignore
"eyesize": load_tensors(weights_path / "eyesize.pt"), # type: ignore
}, {
"age": 0.3,
"cartoon_style": -0.2,
"dpo": 1.4,
"eyesize": -0.2,
}
@pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
@@ -266,6 +286,11 @@ def expected_freeu(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
@pytest.fixture
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
@pytest.fixture
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
assets = Path(__file__).parent.parent.parent / "assets"
@@ -1034,7 +1059,7 @@ def test_diffusion_lora(
sd15.set_inference_steps(30)
SDLoraManager(sd15).load(lora_weights, scale=1)
SDLoraManager(sd15).add_loras("pokemon", lora_weights, scale=1)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -1067,7 +1092,7 @@ def test_diffusion_sdxl_lora(
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
SDLoraManager(sdxl).load(lora_weights, scale=lora_scale)
SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
@@ -1094,6 +1119,54 @@ def test_diffusion_sdxl_lora(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_sdxl_multiple_loras(
sdxl_ddim: StableDiffusion_XL,
lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
lora_sliders: tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]],
expected_sdxl_multi_loras: Image.Image,
) -> None:
sdxl = sdxl_ddim
expected_image = expected_sdxl_multi_loras
_, dpo = lora_data_dpo
loras, scales = lora_sliders
loras["dpo"] = dpo
SDLoraManager(sdxl).add_multiple_loras(loras, scales)
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
# except that we are using DDIM instead of sde-dpmsolver++
n_steps = 40
seed = 12341234123
guidance_scale = 4
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(n_steps)
manual_seed(seed=seed)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=guidance_scale,
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_refonly(
sd15_ddim: StableDiffusion_1,

View file

@@ -49,6 +49,7 @@ Special cases:
- `expected_freeu.png`
- `expected_dropy_slime_9752.png`
- `expected_sdxl_dpo_lora.png`
- `expected_sdxl_multi_loras.png`
## Other images

Binary file not shown (new image, 1.7 MiB), presumably the expected_sdxl_multi_loras.png reference added for the new multi-LoRA test.

View file

@@ -0,0 +1,88 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
@pytest.fixture
def manager() -> SDLoraManager:
target = StableDiffusion_1()
return SDLoraManager(target)
@pytest.fixture
def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
return load_tensors(weights_path)
def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_loras("pokemon-lora", tensors=weights)
assert "pokemon-lora" in manager.names
with pytest.raises(AssertionError) as exc:
manager.add_loras("pokemon-lora", tensors=weights)
assert "already exists" in str(exc.value)
def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
assert "pokemon-lora" in manager.names
assert "pokemon-lora2" in manager.names
def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
manager.remove_loras("pokemon-lora")
assert "pokemon-lora" not in manager.names
assert "pokemon-lora2" in manager.names
manager.remove_loras("pokemon-lora2")
assert "pokemon-lora2" not in manager.names
assert len(manager.names) == 0
def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
manager.remove_all()
assert len(manager.names) == 0
def test_get_lora(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_loras("pokemon-lora", tensors=weights)
assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
def test_get_scale(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
assert manager.get_scale("pokemon-lora") == 0.4
def test_names(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
assert manager.names == []
manager.add_loras("pokemon-lora", tensors=weights)
assert manager.names == ["pokemon-lora"]
manager.add_loras("pokemon-lora2", tensors=weights)
assert manager.names == ["pokemon-lora", "pokemon-lora2"]
def test_scales(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
assert manager.scales == {}
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
assert manager.scales == {"pokemon-lora": 0.4}
manager.add_loras("pokemon-lora2", tensors=weights, scale=0.5)
assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}