Load Multiple LoRAs with SDLoraManager

limiteinductive 2024-01-22 14:45:34 +01:00 committed by Benjamin Trom
parent fb2f0e28d4
commit 421da6a3b6
8 changed files with 439 additions and 74 deletions
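The headline change: SDLoraManager.load becomes add_loras, and every LoRA is registered under a caller-chosen name so it can later be rescaled or removed on its own. Below is a minimal usage sketch of the new API; the model setup and the two checkpoint filenames are illustrative placeholders, not part of the commit.

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

sd15 = StableDiffusion_1()  # weights omitted for brevity
manager = SDLoraManager(sd15)

# each LoRA gets a name when it is added...
manager.add_loras("pokemon", tensors=load_from_safetensors("pokemon_lora.safetensors"), scale=1.0)
manager.add_loras("pixel-art", tensors=load_from_safetensors("pixel_art_lora.safetensors"), scale=0.6)

# ...so it can be retuned or dropped later without touching the others
manager.update_scales({"pokemon": 0.8, "pixel-art": 0.4})
manager.remove_loras("pixel-art")
assert manager.names == ["pokemon"]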

View file

@@ -238,6 +238,11 @@ def download_loras():
         "https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
     )
+    dest_folder = os.path.join(test_weights_dir, "loras", "sliders")
+    download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder)
+    download_file("https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder)
+    download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder)
+
 
 def download_preprocessors():
     dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")

View file

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter from torch.nn import Parameter as TorchParameter
@@ -7,18 +6,20 @@ from torch.nn.init import normal_, zeros_
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
-from refiners.fluxion.layers.chain import Chain
 
 
 class Lora(fl.Chain, ABC):
     def __init__(
         self,
+        name: str,
+        /,
         rank: int = 16,
         scale: float = 1.0,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
-        self.rank = rank
+        self.name = name
+        self._rank = rank
         self._scale = scale
 
         super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
@@ -44,6 +45,10 @@ class Lora(fl.Chain, ABC):
         assert isinstance(up_layer, fl.WeightedModule)
         return up_layer
 
+    @property
+    def rank(self) -> int:
+        return self._rank
+
     @property
     def scale(self) -> float:
         return self._scale
@@ -56,19 +61,21 @@ class Lora(fl.Chain, ABC):
     @classmethod
     def from_weights(
         cls,
+        name: str,
+        /,
         down: Tensor,
         up: Tensor,
     ) -> "Lora":
         match (up.ndim, down.ndim):
             case (2, 2):
-                return LinearLora.from_weights(up=up, down=down)
+                return LinearLora.from_weights(name, up=up, down=down)
             case (4, 4):
-                return Conv2dLora.from_weights(up=up, down=down)
+                return Conv2dLora.from_weights(name, up=up, down=down)
             case _:
                 raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
 
     @classmethod
-    def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
+    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
         """
         Create a dictionary of LoRA layers from a state dict.
@@ -80,13 +87,37 @@ class Lora(fl.Chain, ABC):
             list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
         ):
             key = ".".join(down_key.split(".")[:-2])
-            loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
+            loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor)
         return loras
 
     @abstractmethod
-    def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
+    def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
         ...
 
+    def auto_attach(
+        self, target: fl.Chain, exclude: list[str] | None = None
+    ) -> "tuple[LoraAdapter, fl.Chain | None] | None":
+        for layer, parent in target.walk(self.up.__class__):
+            if isinstance(parent, Lora):
+                continue
+
+            if exclude is not None and any(
+                [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
+            ):
+                continue
+
+            if not self.is_compatible(layer):
+                continue
+
+            if isinstance(parent, LoraAdapter):
+                if self.name in parent.names:
+                    continue
+                else:
+                    parent.add_lora(self)
+                    return parent, None
+
+            return LoraAdapter(layer, self), parent
+
     def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
         assert down_weight.shape == self.down.weight.shape
         assert up_weight.shape == self.up.weight.shape
@@ -97,6 +128,8 @@ class Lora(fl.Chain, ABC):
 class LinearLora(Lora):
     def __init__(
         self,
+        name: str,
+        /,
         in_features: int,
         out_features: int,
         rank: int = 16,
@@ -107,35 +140,29 @@ class LinearLora(Lora):
         self.in_features = in_features
         self.out_features = out_features
 
-        super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
+        super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
 
     @classmethod
     def from_weights(
         cls,
+        name: str,
+        /,
         down: Tensor,
         up: Tensor,
     ) -> "LinearLora":
         assert up.ndim == 2 and down.ndim == 2
         assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
         lora = cls(
-            in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
+            name,
+            in_features=down.shape[1],
+            out_features=up.shape[0],
+            rank=down.shape[0],
+            device=up.device,
+            dtype=up.dtype,
         )
         lora.load_weights(down_weight=down, up_weight=up)
         return lora
 
-    def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
-        for layer, parent in target.walk(fl.Linear):
-            if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
-                continue
-
-            if exclude is not None and any(
-                [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
-            ):
-                continue
-
-            if layer.in_features == self.in_features and layer.out_features == self.out_features:
-                return LoraAdapter(target=layer, lora=self), parent
-
     def lora_layers(
         self, device: Device | str | None = None, dtype: DType | None = None
     ) -> tuple[fl.Linear, fl.Linear]:
@@ -156,10 +183,17 @@ class LinearLora(Lora):
             ),
         )
 
+    def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
+        if isinstance(layer, fl.Linear):
+            return layer.in_features == self.in_features and layer.out_features == self.out_features
+        return False
+
 
 class Conv2dLora(Lora):
     def __init__(
         self,
+        name: str,
+        /,
         in_channels: int,
         out_channels: int,
         rank: int = 16,
@@ -176,20 +210,24 @@ class Conv2dLora(Lora):
         self.stride = stride
         self.padding = padding
 
-        super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
+        super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
 
     @classmethod
     def from_weights(
         cls,
+        name: str,
+        /,
         down: Tensor,
         up: Tensor,
     ) -> "Conv2dLora":
         assert up.ndim == 4 and down.ndim == 4
         assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
         down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
+        # padding is set so the spatial dimensions are preserved (assuming stride=1 and kernel_size either 1 or 3)
         down_padding = 1 if down_kernel_size == 3 else 0
         up_padding = 1 if up_kernel_size == 3 else 0
         lora = cls(
+            name,
             in_channels=down.shape[1],
             out_channels=up.shape[0],
             rank=down.shape[0],
@@ -201,25 +239,6 @@ class Conv2dLora(Lora):
         lora.load_weights(down_weight=down, up_weight=up)
         return lora
 
-    def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
-        for layer, parent in target.walk(fl.Conv2d):
-            if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
-                continue
-
-            if exclude is not None and any(
-                [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
-            ):
-                continue
-
-            if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
-                if layer.stride != (self.stride[0], self.stride[0]):
-                    self.down.stride = layer.stride
-
-                return LoraAdapter(
-                    target=layer,
-                    lora=self,
-                ), parent
-
     def lora_layers(
         self, device: Device | str | None = None, dtype: DType | None = None
     ) -> tuple[fl.Conv2d, fl.Conv2d]:
@@ -246,20 +265,47 @@ class Conv2dLora(Lora):
             ),
         )
 
+    def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
+        if (
+            isinstance(layer, fl.Conv2d)
+            and layer.in_channels == self.in_channels
+            and layer.out_channels == self.out_channels
+        ):
+            # stride cannot be inferred from the weights, so we assume it's the same as the layer
+            self.down.stride = layer.stride
+
+            return True
+        return False
+
 
 class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
-    def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
+    def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
         with self.setup_adapter(target):
-            super().__init__(target, lora)
+            super().__init__(target, *loras)
 
     @property
-    def lora(self) -> Lora:
-        return self.ensure_find(Lora)
+    def names(self) -> list[str]:
+        return [lora.name for lora in self.layers(Lora)]
 
     @property
-    def scale(self) -> float:
-        return self.lora.scale
+    def loras(self) -> dict[str, Lora]:
+        return {lora.name: lora for lora in self.layers(Lora)}
 
-    @scale.setter
-    def scale(self, value: float) -> None:
-        self.lora.scale = value
+    @property
+    def scales(self) -> dict[str, float]:
+        return {lora.name: lora.scale for lora in self.layers(Lora)}
+
+    @scales.setter
+    def scale(self, values: dict[str, float]) -> None:
+        for name, value in values.items():
+            self.loras[name].scale = value
+
+    def add_lora(self, lora: Lora, /) -> None:
+        assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
+        self.append(lora)
+
+    def remove_lora(self, name: str, /) -> Lora | None:
+        if name in self.names:
+            lora = self.loras[name]
+            self.remove(lora)
+            return lora
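LoraAdapter is now an fl.Sum of the target plus any number of named Lora layers, with per-name lookup, scaling, and removal. A short sketch mirroring the lora_adapter fixture in tests/adapters/test_lora.py; the names "style" and "detail" are invented for illustration:

import torch

from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.lora import LinearLora, LoraAdapter

target = fl.Linear(320, 128)
adapter = LoraAdapter(
    target,
    LinearLora("style", in_features=320, out_features=128, rank=16, scale=1.0),
    LinearLora("detail", in_features=320, out_features=128, rank=16, scale=0.5),
)

assert set(adapter.names) == {"style", "detail"}
adapter.scale = {"style": 0.2, "detail": 0.8}  # per-name scale setter (see test_scale_setter_lora_adapter)
adapter.remove_lora("detail")                  # returns the removed Lora, or None if the name is unknown

y = adapter(torch.randn(1, 320))               # target output plus the remaining scaled LoRA delta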

View file

@@ -26,19 +26,20 @@ class SDLoraManager:
         assert isinstance(clip_text_encoder, fl.Chain)
         return clip_text_encoder
 
-    def load(
+    def add_loras(
         self,
-        tensors: dict[str, Tensor],
+        name: str,
         /,
+        tensors: dict[str, Tensor],
         scale: float = 1.0,
     ) -> None:
         """Load the LoRA weights from a dictionary of tensors.
 
         Expects the keys to be in the commonly found formats on CivitAI's hub.
         """
-        assert len(self.lora_adapters) == 0, "Loras already loaded"
+        assert name not in self.names, f"LoRA {name} already exists"
         loras = Lora.from_dict(
-            {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
+            name, {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
         )
         loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
@@ -46,16 +47,25 @@ class SDLoraManager:
         if not "unet" in loras and not "text" in loras:
             loras = {f"unet_{key}": loras[key] for key in loras.keys()}
 
-        self.load_unet(loras)
-        self.load_text_encoder(loras)
+        self.add_loras_to_unet(loras)
+        self.add_loras_to_text_encoder(loras)
 
-        self.scale = scale
+        self.set_scale(name, scale)
 
-    def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
+    def add_multiple_loras(
+        self,
+        /,
+        tensors: dict[str, dict[str, Tensor]],
+        scale: dict[str, float] | None = None,
+    ) -> None:
+        for name, lora_tensors in tensors.items():
+            self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
+
+    def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
         SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
 
-    def load_unet(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
         unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
         exclude: list[str] = []
         exclude = [
@@ -65,14 +75,43 @@ class SDLoraManager:
         ]
         SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
 
-    def unload(self) -> None:
+    def remove_loras(self, *names: str) -> None:
+        for lora_adapter in self.lora_adapters:
+            for name in names:
+                lora_adapter.remove_lora(name)
+
+            if len(lora_adapter.loras) == 0:
+                lora_adapter.eject()
+
+    def remove_all(self) -> None:
         for lora_adapter in self.lora_adapters:
             lora_adapter.eject()
 
+    def get_loras_by_name(self, name: str, /) -> list[Lora]:
+        return [lora for lora in self.loras if lora.name == name]
+
+    def get_scale(self, name: str, /) -> float:
+        loras = self.get_loras_by_name(name)
+        assert all([lora.scale == loras[0].scale for lora in loras]), "lora scales are not all the same"
+        return loras[0].scale
+
+    def set_scale(self, name: str, scale: float, /) -> None:
+        self.update_scales({name: scale})
+
+    def update_scales(self, scales: dict[str, float], /) -> None:
+        assert all([name in self.names for name in scales]), f"Scales keys must be a subset of {self.names}"
+        for name, scale in scales.items():
+            for lora in self.get_loras_by_name(name):
+                lora.scale = scale
+
     @property
     def loras(self) -> list[Lora]:
         return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
 
+    @property
+    def names(self) -> list[str]:
+        return list(set(lora.name for lora in self.loras))
+
     @property
     def lora_adapters(self) -> list[LoraAdapter]:
         return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
@@ -87,15 +126,8 @@ class SDLoraManager:
         }
 
     @property
-    def scale(self) -> float:
-        assert len(self.loras) > 0, "No loras found"
-        assert all([lora.scale == self.loras[0].scale for lora in self.loras])
-        return self.loras[0].scale
-
-    @scale.setter
-    def scale(self, value: float) -> None:
-        for lora in self.loras:
-            lora.scale = value
+    def scales(self) -> dict[str, float]:
+        return {name: self.get_scale(name) for name in self.names}
 
     @staticmethod
     def pad(input: str, /, padding_length: int = 2) -> str:
@@ -130,6 +162,8 @@ class SDLoraManager:
         for key, lora in loras.items():
             if attach := lora.auto_attach(target, exclude=exclude):
                 adapter, parent = attach
-                adapter.inject(parent)
+                # if parent is None, `adapter` is already attached and `lora` has been added to it
+                if parent is not None:
+                    adapter.inject(parent)
             else:
                 failed_loras[key] = lora
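With named LoRAs, the manager can also load several checkpoints in one call and drive their scales independently afterwards. A sketch of add_multiple_loras and the scale helpers; the names "subject" and "style" and the weights path are placeholders (the unit tests further down use the pokemon-lora checkpoint from the test weights folder):

from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

weights = load_tensors("pytorch_lora_weights.bin")  # placeholder path, any CivitAI-style LoRA state dict
manager = SDLoraManager(StableDiffusion_1())

# one call, several named LoRAs, each with its own initial scale
manager.add_multiple_loras(
    tensors={"subject": weights, "style": weights},
    scale={"subject": 0.4, "style": 1.0},
)

assert manager.get_scale("subject") == 0.4
manager.set_scale("style", 0.7)  # retune a single LoRA without reloading
assert manager.scales == {"subject": 0.4, "style": 0.7}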

tests/adapters/test_lora.py (new file, 118 lines)
View file

@@ -0,0 +1,118 @@
import pytest
import torch

from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter


@pytest.fixture
def lora() -> LinearLora:
    return LinearLora("test", in_features=320, out_features=128, rank=16)


@pytest.fixture
def conv_lora() -> Lora:
    return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)


def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
    assert lora.name == "test"
    assert lora.rank == lora.down.out_features == lora.up.in_features == 16
    assert lora.scale == 1.0
    assert lora.in_features == lora.down.in_features == 320
    assert lora.out_features == lora.up.out_features == 128

    assert conv_lora.name == "conv_test"
    assert conv_lora.rank == conv_lora.down.out_channels == conv_lora.up.in_channels == 4
    assert conv_lora.scale == 1.0
    assert conv_lora.in_channels == conv_lora.down.in_channels == 16
    assert conv_lora.out_channels == conv_lora.up.out_channels == 8
    assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
    # padding is set so the spatial dimensions are preserved
    assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)


def test_scale_setter(lora: LinearLora) -> None:
    lora.scale = 2.0
    assert lora.scale == 2.0
    assert lora.ensure_find(fl.Multiply).scale == 2.0


def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
    new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
    x = torch.randn(1, 320)
    assert torch.allclose(lora(x), new_lora(x))

    new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
    x = torch.randn(1, 16, 64, 64)
    assert torch.allclose(conv_lora(x), new_conv_lora(x))


def test_from_dict() -> None:
    state_dict = {
        "down.weight": torch.randn(320, 128),
        "up.weight": torch.randn(128, 320),
        "this.is_not_used.alpha": torch.randn(1, 320),
        "probably.a.conv.down.weight": torch.randn(4, 16, 3, 3),
        "probably.a.conv.up.weight": torch.randn(8, 4, 1, 1),
    }
    loras = Lora.from_dict("test", state_dict=state_dict)

    assert len(loras) == 2
    linear_lora, conv_lora = loras.values()
    assert isinstance(linear_lora, LinearLora)
    assert isinstance(conv_lora, Conv2dLora)
    assert linear_lora.name == "test"
    assert conv_lora.name == "test"


@pytest.fixture
def lora_adapter() -> LoraAdapter:
    target = fl.Linear(320, 128)
    lora1 = LinearLora("test1", in_features=320, out_features=128, rank=16, scale=2.0)
    lora2 = LinearLora("test2", in_features=320, out_features=128, rank=16, scale=-1.0)
    return LoraAdapter(target, lora1, lora2)


def test_names(lora_adapter: LoraAdapter) -> None:
    assert set(lora_adapter.names) == {"test1", "test2"}


def test_loras(lora_adapter: LoraAdapter) -> None:
    assert set(lora_adapter.loras.keys()) == {"test1", "test2"}


def test_scales(lora_adapter: LoraAdapter) -> None:
    assert set(lora_adapter.scales.keys()) == {"test1", "test2"}
    assert lora_adapter.scales["test1"] == 2.0
    assert lora_adapter.scales["test2"] == -1.0


def test_scale_setter_lora_adapter(lora_adapter: LoraAdapter) -> None:
    lora_adapter.scale = {"test1": 0.0, "test2": 3.0}
    assert lora_adapter.scales == {"test1": 0.0, "test2": 3.0}


def test_add_lora(lora_adapter: LoraAdapter) -> None:
    lora3 = LinearLora("test3", in_features=320, out_features=128, rank=16)
    lora_adapter.add_lora(lora3)
    assert "test3" in lora_adapter.names


def test_remove_lora(lora_adapter: LoraAdapter) -> None:
    lora_adapter.remove_lora("test1")
    assert "test1" not in lora_adapter.names


def test_add_existing_lora(lora_adapter: LoraAdapter) -> None:
    lora3 = LinearLora("test1", in_features=320, out_features=128, rank=16)
    with pytest.raises(AssertionError):
        lora_adapter.add_lora(lora3)


def test_remove_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
    assert lora_adapter.remove_lora("test3") is None


def test_set_scale_for_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
    with pytest.raises(KeyError):
        lora_adapter.scale = {"test3": 2.0}

View file

@@ -216,6 +216,26 @@ def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image,
     return expected_image, tensors
 
 
+@pytest.fixture(scope="module")
+def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
+    weights_path = test_weights_path / "loras" / "sliders"
+
+    if not weights_path.is_dir():
+        warn(f"could not find weights at {weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    return {
+        "age": load_tensors(weights_path / "age.pt"),  # type: ignore
+        "cartoon_style": load_tensors(weights_path / "cartoon_style.pt"),  # type: ignore
+        "eyesize": load_tensors(weights_path / "eyesize.pt"),  # type: ignore
+    }, {
+        "age": 0.3,
+        "cartoon_style": -0.2,
+        "dpo": 1.4,
+        "eyesize": -0.2,
+    }
+
+
 @pytest.fixture
 def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
@@ -266,6 +286,11 @@ def expected_freeu(ref_path: Path) -> Image.Image:
     return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
 
 
+@pytest.fixture
+def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
+    return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
+
+
 @pytest.fixture
 def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
     assets = Path(__file__).parent.parent.parent / "assets"
@@ -1034,7 +1059,7 @@ def test_diffusion_lora(
     sd15.set_inference_steps(30)
 
-    SDLoraManager(sd15).load(lora_weights, scale=1)
+    SDLoraManager(sd15).add_loras("pokemon", lora_weights, scale=1)
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -1067,7 +1092,7 @@ def test_diffusion_sdxl_lora(
     prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
     negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
 
-    SDLoraManager(sdxl).load(lora_weights, scale=lora_scale)
+    SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale)
 
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text=prompt, negative_text=negative_prompt
@@ -1094,6 +1119,54 @@ def test_diffusion_sdxl_lora(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
+@no_grad()
+def test_diffusion_sdxl_multiple_loras(
+    sdxl_ddim: StableDiffusion_XL,
+    lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
+    lora_sliders: tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]],
+    expected_sdxl_multi_loras: Image.Image,
+) -> None:
+    sdxl = sdxl_ddim
+    expected_image = expected_sdxl_multi_loras
+    _, dpo = lora_data_dpo
+    loras, scales = lora_sliders
+    loras["dpo"] = dpo
+
+    SDLoraManager(sdxl).add_multiple_loras(loras, scales)
+
+    # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
+    # except that we are using DDIM instead of sde-dpmsolver++
+    n_steps = 40
+    seed = 12341234123
+    guidance_scale = 4
+
+    prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
+    negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
+
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text=prompt, negative_text=negative_prompt
+    )
+    time_ids = sdxl.default_time_ids
+
+    sdxl.set_inference_steps(n_steps)
+
+    manual_seed(seed=seed)
+    x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
+
+    for step in sdxl.steps:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=guidance_scale,
+        )
+    predicted_image = sdxl.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
+
+
 @no_grad()
 def test_diffusion_refonly(
     sd15_ddim: StableDiffusion_1,

View file

@@ -49,6 +49,7 @@ Special cases:
 - `expected_freeu.png`
 - `expected_dropy_slime_9752.png`
 - `expected_sdxl_dpo_lora.png`
+- `expected_sdxl_multi_loras.png`
 
 ## Other images

Binary file not shown (new image, 1.7 MiB).

View file

@@ -0,0 +1,88 @@
from pathlib import Path
from warnings import warn

import pytest
import torch

from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager


@pytest.fixture
def manager() -> SDLoraManager:
    target = StableDiffusion_1()
    return SDLoraManager(target)


@pytest.fixture
def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
    weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"

    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    return load_tensors(weights_path)


def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    manager.add_loras("pokemon-lora", tensors=weights)
    assert "pokemon-lora" in manager.names

    with pytest.raises(AssertionError) as exc:
        manager.add_loras("pokemon-lora", tensors=weights)
    assert "already exists" in str(exc.value)


def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
    assert "pokemon-lora" in manager.names
    assert "pokemon-lora2" in manager.names


def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
    manager.remove_loras("pokemon-lora")
    assert "pokemon-lora" not in manager.names
    assert "pokemon-lora2" in manager.names

    manager.remove_loras("pokemon-lora2")
    assert "pokemon-lora2" not in manager.names
    assert len(manager.names) == 0


def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
    manager.remove_all()
    assert len(manager.names) == 0


def test_get_lora(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    manager.add_loras("pokemon-lora", tensors=weights)
    assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))


def test_get_scale(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
    assert manager.get_scale("pokemon-lora") == 0.4


def test_names(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    assert manager.names == []

    manager.add_loras("pokemon-lora", tensors=weights)
    assert manager.names == ["pokemon-lora"]

    manager.add_loras("pokemon-lora2", tensors=weights)
    assert manager.names == ["pokemon-lora", "pokemon-lora2"]


def test_scales(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
    assert manager.scales == {}

    manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
    assert manager.scales == {"pokemon-lora": 0.4}

    manager.add_loras("pokemon-lora2", tensors=weights, scale=0.5)
    assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}