mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
Load Multiple LoRAs with SDLoraManager
This commit is contained in:
parent
fb2f0e28d4
commit
421da6a3b6
|
@ -238,6 +238,11 @@ def download_loras():
|
||||||
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
|
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dest_folder = os.path.join(test_weights_dir, "loras", "sliders")
|
||||||
|
download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder)
|
||||||
|
download_file("https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder)
|
||||||
|
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder)
|
||||||
|
|
||||||
|
|
||||||
def download_preprocessors():
|
def download_preprocessors():
|
||||||
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.nn import Parameter as TorchParameter
|
from torch.nn import Parameter as TorchParameter
|
||||||
|
@ -7,18 +6,20 @@ from torch.nn.init import normal_, zeros_
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
from refiners.fluxion.layers.chain import Chain
|
|
||||||
|
|
||||||
|
|
||||||
class Lora(fl.Chain, ABC):
|
class Lora(fl.Chain, ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
name: str,
|
||||||
|
/,
|
||||||
rank: int = 16,
|
rank: int = 16,
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.rank = rank
|
self.name = name
|
||||||
|
self._rank = rank
|
||||||
self._scale = scale
|
self._scale = scale
|
||||||
|
|
||||||
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
|
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
|
||||||
|
@ -44,6 +45,10 @@ class Lora(fl.Chain, ABC):
|
||||||
assert isinstance(up_layer, fl.WeightedModule)
|
assert isinstance(up_layer, fl.WeightedModule)
|
||||||
return up_layer
|
return up_layer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rank(self) -> int:
|
||||||
|
return self._rank
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scale(self) -> float:
|
||||||
return self._scale
|
return self._scale
|
||||||
|
@ -56,19 +61,21 @@ class Lora(fl.Chain, ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_weights(
|
def from_weights(
|
||||||
cls,
|
cls,
|
||||||
|
name: str,
|
||||||
|
/,
|
||||||
down: Tensor,
|
down: Tensor,
|
||||||
up: Tensor,
|
up: Tensor,
|
||||||
) -> "Lora":
|
) -> "Lora":
|
||||||
match (up.ndim, down.ndim):
|
match (up.ndim, down.ndim):
|
||||||
case (2, 2):
|
case (2, 2):
|
||||||
return LinearLora.from_weights(up=up, down=down)
|
return LinearLora.from_weights(name, up=up, down=down)
|
||||||
case (4, 4):
|
case (4, 4):
|
||||||
return Conv2dLora.from_weights(up=up, down=down)
|
return Conv2dLora.from_weights(name, up=up, down=down)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
|
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
|
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
|
||||||
"""
|
"""
|
||||||
Create a dictionary of LoRA layers from a state dict.
|
Create a dictionary of LoRA layers from a state dict.
|
||||||
|
|
||||||
|
@ -80,13 +87,37 @@ class Lora(fl.Chain, ABC):
|
||||||
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
|
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
|
||||||
):
|
):
|
||||||
key = ".".join(down_key.split(".")[:-2])
|
key = ".".join(down_key.split(".")[:-2])
|
||||||
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
|
loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor)
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
|
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def auto_attach(
|
||||||
|
self, target: fl.Chain, exclude: list[str] | None = None
|
||||||
|
) -> "tuple[LoraAdapter, fl.Chain | None] | None":
|
||||||
|
for layer, parent in target.walk(self.up.__class__):
|
||||||
|
if isinstance(parent, Lora):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if exclude is not None and any(
|
||||||
|
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not self.is_compatible(layer):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(parent, LoraAdapter):
|
||||||
|
if self.name in parent.names:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
parent.add_lora(self)
|
||||||
|
return parent, None
|
||||||
|
|
||||||
|
return LoraAdapter(layer, self), parent
|
||||||
|
|
||||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||||
assert down_weight.shape == self.down.weight.shape
|
assert down_weight.shape == self.down.weight.shape
|
||||||
assert up_weight.shape == self.up.weight.shape
|
assert up_weight.shape == self.up.weight.shape
|
||||||
|
@ -97,6 +128,8 @@ class Lora(fl.Chain, ABC):
|
||||||
class LinearLora(Lora):
|
class LinearLora(Lora):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
name: str,
|
||||||
|
/,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
rank: int = 16,
|
rank: int = 16,
|
||||||
|
@ -107,35 +140,29 @@ class LinearLora(Lora):
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
|
|
||||||
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
|
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_weights(
|
def from_weights(
|
||||||
cls,
|
cls,
|
||||||
|
name: str,
|
||||||
|
/,
|
||||||
down: Tensor,
|
down: Tensor,
|
||||||
up: Tensor,
|
up: Tensor,
|
||||||
) -> "LinearLora":
|
) -> "LinearLora":
|
||||||
assert up.ndim == 2 and down.ndim == 2
|
assert up.ndim == 2 and down.ndim == 2
|
||||||
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
||||||
lora = cls(
|
lora = cls(
|
||||||
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
|
name,
|
||||||
|
in_features=down.shape[1],
|
||||||
|
out_features=up.shape[0],
|
||||||
|
rank=down.shape[0],
|
||||||
|
device=up.device,
|
||||||
|
dtype=up.dtype,
|
||||||
)
|
)
|
||||||
lora.load_weights(down_weight=down, up_weight=up)
|
lora.load_weights(down_weight=down, up_weight=up)
|
||||||
return lora
|
return lora
|
||||||
|
|
||||||
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
|
|
||||||
for layer, parent in target.walk(fl.Linear):
|
|
||||||
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if exclude is not None and any(
|
|
||||||
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if layer.in_features == self.in_features and layer.out_features == self.out_features:
|
|
||||||
return LoraAdapter(target=layer, lora=self), parent
|
|
||||||
|
|
||||||
def lora_layers(
|
def lora_layers(
|
||||||
self, device: Device | str | None = None, dtype: DType | None = None
|
self, device: Device | str | None = None, dtype: DType | None = None
|
||||||
) -> tuple[fl.Linear, fl.Linear]:
|
) -> tuple[fl.Linear, fl.Linear]:
|
||||||
|
@ -156,10 +183,17 @@ class LinearLora(Lora):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
|
||||||
|
if isinstance(layer, fl.Linear):
|
||||||
|
return layer.in_features == self.in_features and layer.out_features == self.out_features
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Conv2dLora(Lora):
|
class Conv2dLora(Lora):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
name: str,
|
||||||
|
/,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
rank: int = 16,
|
rank: int = 16,
|
||||||
|
@ -176,20 +210,24 @@ class Conv2dLora(Lora):
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
|
|
||||||
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
|
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_weights(
|
def from_weights(
|
||||||
cls,
|
cls,
|
||||||
|
name: str,
|
||||||
|
/,
|
||||||
down: Tensor,
|
down: Tensor,
|
||||||
up: Tensor,
|
up: Tensor,
|
||||||
) -> "Conv2dLora":
|
) -> "Conv2dLora":
|
||||||
assert up.ndim == 4 and down.ndim == 4
|
assert up.ndim == 4 and down.ndim == 4
|
||||||
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
||||||
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
|
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
|
||||||
|
# padding is set so the spatial dimensions are preserved (assuming stride=1 and kernel_size either 1 or 3)
|
||||||
down_padding = 1 if down_kernel_size == 3 else 0
|
down_padding = 1 if down_kernel_size == 3 else 0
|
||||||
up_padding = 1 if up_kernel_size == 3 else 0
|
up_padding = 1 if up_kernel_size == 3 else 0
|
||||||
lora = cls(
|
lora = cls(
|
||||||
|
name,
|
||||||
in_channels=down.shape[1],
|
in_channels=down.shape[1],
|
||||||
out_channels=up.shape[0],
|
out_channels=up.shape[0],
|
||||||
rank=down.shape[0],
|
rank=down.shape[0],
|
||||||
|
@ -201,25 +239,6 @@ class Conv2dLora(Lora):
|
||||||
lora.load_weights(down_weight=down, up_weight=up)
|
lora.load_weights(down_weight=down, up_weight=up)
|
||||||
return lora
|
return lora
|
||||||
|
|
||||||
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
|
|
||||||
for layer, parent in target.walk(fl.Conv2d):
|
|
||||||
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if exclude is not None and any(
|
|
||||||
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
|
|
||||||
if layer.stride != (self.stride[0], self.stride[0]):
|
|
||||||
self.down.stride = layer.stride
|
|
||||||
|
|
||||||
return LoraAdapter(
|
|
||||||
target=layer,
|
|
||||||
lora=self,
|
|
||||||
), parent
|
|
||||||
|
|
||||||
def lora_layers(
|
def lora_layers(
|
||||||
self, device: Device | str | None = None, dtype: DType | None = None
|
self, device: Device | str | None = None, dtype: DType | None = None
|
||||||
) -> tuple[fl.Conv2d, fl.Conv2d]:
|
) -> tuple[fl.Conv2d, fl.Conv2d]:
|
||||||
|
@ -246,20 +265,47 @@ class Conv2dLora(Lora):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
|
||||||
|
if (
|
||||||
|
isinstance(layer, fl.Conv2d)
|
||||||
|
and layer.in_channels == self.in_channels
|
||||||
|
and layer.out_channels == self.out_channels
|
||||||
|
):
|
||||||
|
# stride cannot be inferred from the weights, so we assume it's the same as the layer
|
||||||
|
self.down.stride = layer.stride
|
||||||
|
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
||||||
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
|
def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(target, lora)
|
super().__init__(target, *loras)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lora(self) -> Lora:
|
def names(self) -> list[str]:
|
||||||
return self.ensure_find(Lora)
|
return [lora.name for lora in self.layers(Lora)]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def loras(self) -> dict[str, Lora]:
|
||||||
return self.lora.scale
|
return {lora.name: lora for lora in self.layers(Lora)}
|
||||||
|
|
||||||
@scale.setter
|
@property
|
||||||
def scale(self, value: float) -> None:
|
def scales(self) -> dict[str, float]:
|
||||||
self.lora.scale = value
|
return {lora.name: lora.scale for lora in self.layers(Lora)}
|
||||||
|
|
||||||
|
@scales.setter
|
||||||
|
def scale(self, values: dict[str, float]) -> None:
|
||||||
|
for name, value in values.items():
|
||||||
|
self.loras[name].scale = value
|
||||||
|
|
||||||
|
def add_lora(self, lora: Lora, /) -> None:
|
||||||
|
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
|
||||||
|
self.append(lora)
|
||||||
|
|
||||||
|
def remove_lora(self, name: str, /) -> Lora | None:
|
||||||
|
if name in self.names:
|
||||||
|
lora = self.loras[name]
|
||||||
|
self.remove(lora)
|
||||||
|
return lora
|
||||||
|
|
|
@ -26,19 +26,20 @@ class SDLoraManager:
|
||||||
assert isinstance(clip_text_encoder, fl.Chain)
|
assert isinstance(clip_text_encoder, fl.Chain)
|
||||||
return clip_text_encoder
|
return clip_text_encoder
|
||||||
|
|
||||||
def load(
|
def add_loras(
|
||||||
self,
|
self,
|
||||||
tensors: dict[str, Tensor],
|
name: str,
|
||||||
/,
|
/,
|
||||||
|
tensors: dict[str, Tensor],
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load the LoRA weights from a dictionary of tensors.
|
"""Load the LoRA weights from a dictionary of tensors.
|
||||||
|
|
||||||
Expects the keys to be in the commonly found formats on CivitAI's hub.
|
Expects the keys to be in the commonly found formats on CivitAI's hub.
|
||||||
"""
|
"""
|
||||||
assert len(self.lora_adapters) == 0, "Loras already loaded"
|
assert name not in self.names, f"LoRA {name} already exists"
|
||||||
loras = Lora.from_dict(
|
loras = Lora.from_dict(
|
||||||
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
|
name, {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
|
||||||
)
|
)
|
||||||
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
|
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
|
||||||
|
|
||||||
|
@ -46,16 +47,25 @@ class SDLoraManager:
|
||||||
if not "unet" in loras and not "text" in loras:
|
if not "unet" in loras and not "text" in loras:
|
||||||
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
|
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
|
||||||
|
|
||||||
self.load_unet(loras)
|
self.add_loras_to_unet(loras)
|
||||||
self.load_text_encoder(loras)
|
self.add_loras_to_text_encoder(loras)
|
||||||
|
|
||||||
self.scale = scale
|
self.set_scale(name, scale)
|
||||||
|
|
||||||
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
|
def add_multiple_loras(
|
||||||
|
self,
|
||||||
|
/,
|
||||||
|
tensors: dict[str, dict[str, Tensor]],
|
||||||
|
scale: dict[str, float] | None = None,
|
||||||
|
) -> None:
|
||||||
|
for name, lora_tensors in tensors.items():
|
||||||
|
self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
|
||||||
|
|
||||||
|
def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
|
||||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||||
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
|
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
|
||||||
|
|
||||||
def load_unet(self, loras: dict[str, Lora], /) -> None:
|
def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
|
||||||
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
||||||
exclude: list[str] = []
|
exclude: list[str] = []
|
||||||
exclude = [
|
exclude = [
|
||||||
|
@ -65,14 +75,43 @@ class SDLoraManager:
|
||||||
]
|
]
|
||||||
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
|
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
|
||||||
|
|
||||||
def unload(self) -> None:
|
def remove_loras(self, *names: str) -> None:
|
||||||
|
for lora_adapter in self.lora_adapters:
|
||||||
|
for name in names:
|
||||||
|
lora_adapter.remove_lora(name)
|
||||||
|
|
||||||
|
if len(lora_adapter.loras) == 0:
|
||||||
|
lora_adapter.eject()
|
||||||
|
|
||||||
|
def remove_all(self) -> None:
|
||||||
for lora_adapter in self.lora_adapters:
|
for lora_adapter in self.lora_adapters:
|
||||||
lora_adapter.eject()
|
lora_adapter.eject()
|
||||||
|
|
||||||
|
def get_loras_by_name(self, name: str, /) -> list[Lora]:
|
||||||
|
return [lora for lora in self.loras if lora.name == name]
|
||||||
|
|
||||||
|
def get_scale(self, name: str, /) -> float:
|
||||||
|
loras = self.get_loras_by_name(name)
|
||||||
|
assert all([lora.scale == loras[0].scale for lora in loras]), "lora scales are not all the same"
|
||||||
|
return loras[0].scale
|
||||||
|
|
||||||
|
def set_scale(self, name: str, scale: float, /) -> None:
|
||||||
|
self.update_scales({name: scale})
|
||||||
|
|
||||||
|
def update_scales(self, scales: dict[str, float], /) -> None:
|
||||||
|
assert all([name in self.names for name in scales]), f"Scales keys must be a subset of {self.names}"
|
||||||
|
for name, scale in scales.items():
|
||||||
|
for lora in self.get_loras_by_name(name):
|
||||||
|
lora.scale = scale
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loras(self) -> list[Lora]:
|
def loras(self) -> list[Lora]:
|
||||||
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
|
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def names(self) -> list[str]:
|
||||||
|
return list(set(lora.name for lora in self.loras))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lora_adapters(self) -> list[LoraAdapter]:
|
def lora_adapters(self) -> list[LoraAdapter]:
|
||||||
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
|
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
|
||||||
|
@ -87,15 +126,8 @@ class SDLoraManager:
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scales(self) -> dict[str, float]:
|
||||||
assert len(self.loras) > 0, "No loras found"
|
return {name: self.get_scale(name) for name in self.names}
|
||||||
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
|
|
||||||
return self.loras[0].scale
|
|
||||||
|
|
||||||
@scale.setter
|
|
||||||
def scale(self, value: float) -> None:
|
|
||||||
for lora in self.loras:
|
|
||||||
lora.scale = value
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pad(input: str, /, padding_length: int = 2) -> str:
|
def pad(input: str, /, padding_length: int = 2) -> str:
|
||||||
|
@ -130,6 +162,8 @@ class SDLoraManager:
|
||||||
for key, lora in loras.items():
|
for key, lora in loras.items():
|
||||||
if attach := lora.auto_attach(target, exclude=exclude):
|
if attach := lora.auto_attach(target, exclude=exclude):
|
||||||
adapter, parent = attach
|
adapter, parent = attach
|
||||||
|
# if parent is None, `adapter` is already attached and `lora` has been added to it
|
||||||
|
if parent is not None:
|
||||||
adapter.inject(parent)
|
adapter.inject(parent)
|
||||||
else:
|
else:
|
||||||
failed_loras[key] = lora
|
failed_loras[key] = lora
|
||||||
|
|
118
tests/adapters/test_lora.py
Normal file
118
tests/adapters/test_lora.py
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.fluxion import layers as fl
|
||||||
|
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def lora() -> LinearLora:
|
||||||
|
return LinearLora("test", in_features=320, out_features=128, rank=16)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def conv_lora() -> Lora:
|
||||||
|
return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
|
||||||
|
assert lora.name == "test"
|
||||||
|
assert lora.rank == lora.down.out_features == lora.up.in_features == 16
|
||||||
|
assert lora.scale == 1.0
|
||||||
|
assert lora.in_features == lora.down.in_features == 320
|
||||||
|
assert lora.out_features == lora.up.out_features == 128
|
||||||
|
|
||||||
|
assert conv_lora.name == "conv_test"
|
||||||
|
assert conv_lora.rank == conv_lora.down.out_channels == conv_lora.up.in_channels == 4
|
||||||
|
assert conv_lora.scale == 1.0
|
||||||
|
assert conv_lora.in_channels == conv_lora.down.in_channels == 16
|
||||||
|
assert conv_lora.out_channels == conv_lora.up.out_channels == 8
|
||||||
|
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
|
||||||
|
# padding is set so the spatial dimensions are preserved
|
||||||
|
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_setter(lora: LinearLora) -> None:
|
||||||
|
lora.scale = 2.0
|
||||||
|
assert lora.scale == 2.0
|
||||||
|
assert lora.ensure_find(fl.Multiply).scale == 2.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
|
||||||
|
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
|
||||||
|
x = torch.randn(1, 320)
|
||||||
|
assert torch.allclose(lora(x), new_lora(x))
|
||||||
|
|
||||||
|
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
|
||||||
|
x = torch.randn(1, 16, 64, 64)
|
||||||
|
assert torch.allclose(conv_lora(x), new_conv_lora(x))
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict() -> None:
|
||||||
|
state_dict = {
|
||||||
|
"down.weight": torch.randn(320, 128),
|
||||||
|
"up.weight": torch.randn(128, 320),
|
||||||
|
"this.is_not_used.alpha": torch.randn(1, 320),
|
||||||
|
"probably.a.conv.down.weight": torch.randn(4, 16, 3, 3),
|
||||||
|
"probably.a.conv.up.weight": torch.randn(8, 4, 1, 1),
|
||||||
|
}
|
||||||
|
loras = Lora.from_dict("test", state_dict=state_dict)
|
||||||
|
assert len(loras) == 2
|
||||||
|
linear_lora, conv_lora = loras.values()
|
||||||
|
assert isinstance(linear_lora, LinearLora)
|
||||||
|
assert isinstance(conv_lora, Conv2dLora)
|
||||||
|
assert linear_lora.name == "test"
|
||||||
|
assert conv_lora.name == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def lora_adapter() -> LoraAdapter:
|
||||||
|
target = fl.Linear(320, 128)
|
||||||
|
lora1 = LinearLora("test1", in_features=320, out_features=128, rank=16, scale=2.0)
|
||||||
|
lora2 = LinearLora("test2", in_features=320, out_features=128, rank=16, scale=-1.0)
|
||||||
|
return LoraAdapter(target, lora1, lora2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_names(lora_adapter: LoraAdapter) -> None:
|
||||||
|
assert set(lora_adapter.names) == {"test1", "test2"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_loras(lora_adapter: LoraAdapter) -> None:
|
||||||
|
assert set(lora_adapter.loras.keys()) == {"test1", "test2"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_scales(lora_adapter: LoraAdapter) -> None:
|
||||||
|
assert set(lora_adapter.scales.keys()) == {"test1", "test2"}
|
||||||
|
assert lora_adapter.scales["test1"] == 2.0
|
||||||
|
assert lora_adapter.scales["test2"] == -1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_setter_lora_adapter(lora_adapter: LoraAdapter) -> None:
|
||||||
|
lora_adapter.scale = {"test1": 0.0, "test2": 3.0}
|
||||||
|
assert lora_adapter.scales == {"test1": 0.0, "test2": 3.0}
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_lora(lora_adapter: LoraAdapter) -> None:
|
||||||
|
lora3 = LinearLora("test3", in_features=320, out_features=128, rank=16)
|
||||||
|
lora_adapter.add_lora(lora3)
|
||||||
|
assert "test3" in lora_adapter.names
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_lora(lora_adapter: LoraAdapter) -> None:
|
||||||
|
lora_adapter.remove_lora("test1")
|
||||||
|
assert "test1" not in lora_adapter.names
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_existing_lora(lora_adapter: LoraAdapter) -> None:
|
||||||
|
lora3 = LinearLora("test1", in_features=320, out_features=128, rank=16)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
lora_adapter.add_lora(lora3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
|
||||||
|
assert lora_adapter.remove_lora("test3") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_scale_for_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
lora_adapter.scale = {"test3": 2.0}
|
|
@ -216,6 +216,26 @@ def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image,
|
||||||
return expected_image, tensors
|
return expected_image, tensors
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
|
||||||
|
weights_path = test_weights_path / "loras" / "sliders"
|
||||||
|
|
||||||
|
if not weights_path.is_dir():
|
||||||
|
warn(f"could not find weights at {weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"age": load_tensors(weights_path / "age.pt"), # type: ignore
|
||||||
|
"cartoon_style": load_tensors(weights_path / "cartoon_style.pt"), # type: ignore
|
||||||
|
"eyesize": load_tensors(weights_path / "eyesize.pt"), # type: ignore
|
||||||
|
}, {
|
||||||
|
"age": 0.3,
|
||||||
|
"cartoon_style": -0.2,
|
||||||
|
"dpo": 1.4,
|
||||||
|
"eyesize": -0.2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
|
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
|
||||||
|
@ -266,6 +286,11 @@ def expected_freeu(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
|
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
|
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
|
||||||
assets = Path(__file__).parent.parent.parent / "assets"
|
assets = Path(__file__).parent.parent.parent / "assets"
|
||||||
|
@ -1034,7 +1059,7 @@ def test_diffusion_lora(
|
||||||
|
|
||||||
sd15.set_inference_steps(30)
|
sd15.set_inference_steps(30)
|
||||||
|
|
||||||
SDLoraManager(sd15).load(lora_weights, scale=1)
|
SDLoraManager(sd15).add_loras("pokemon", lora_weights, scale=1)
|
||||||
|
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
@ -1067,7 +1092,7 @@ def test_diffusion_sdxl_lora(
|
||||||
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
|
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
|
||||||
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
|
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
|
||||||
|
|
||||||
SDLoraManager(sdxl).load(lora_weights, scale=lora_scale)
|
SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale)
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
text=prompt, negative_text=negative_prompt
|
text=prompt, negative_text=negative_prompt
|
||||||
|
@ -1094,6 +1119,54 @@ def test_diffusion_sdxl_lora(
|
||||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_diffusion_sdxl_multiple_loras(
|
||||||
|
sdxl_ddim: StableDiffusion_XL,
|
||||||
|
lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
|
||||||
|
lora_sliders: tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]],
|
||||||
|
expected_sdxl_multi_loras: Image.Image,
|
||||||
|
) -> None:
|
||||||
|
sdxl = sdxl_ddim
|
||||||
|
expected_image = expected_sdxl_multi_loras
|
||||||
|
_, dpo = lora_data_dpo
|
||||||
|
loras, scales = lora_sliders
|
||||||
|
loras["dpo"] = dpo
|
||||||
|
|
||||||
|
SDLoraManager(sdxl).add_multiple_loras(loras, scales)
|
||||||
|
|
||||||
|
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
|
||||||
|
# except that we are using DDIM instead of sde-dpmsolver++
|
||||||
|
n_steps = 40
|
||||||
|
seed = 12341234123
|
||||||
|
guidance_scale = 4
|
||||||
|
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
|
||||||
|
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
|
||||||
|
|
||||||
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
|
text=prompt, negative_text=negative_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
time_ids = sdxl.default_time_ids
|
||||||
|
sdxl.set_inference_steps(n_steps)
|
||||||
|
|
||||||
|
manual_seed(seed=seed)
|
||||||
|
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||||
|
|
||||||
|
for step in sdxl.steps:
|
||||||
|
x = sdxl(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
pooled_text_embedding=pooled_text_embedding,
|
||||||
|
time_ids=time_ids,
|
||||||
|
condition_scale=guidance_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
predicted_image = sdxl.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_refonly(
|
def test_diffusion_refonly(
|
||||||
sd15_ddim: StableDiffusion_1,
|
sd15_ddim: StableDiffusion_1,
|
||||||
|
|
|
@ -49,6 +49,7 @@ Special cases:
|
||||||
- `expected_freeu.png`
|
- `expected_freeu.png`
|
||||||
- `expected_dropy_slime_9752.png`
|
- `expected_dropy_slime_9752.png`
|
||||||
- `expected_sdxl_dpo_lora.png`
|
- `expected_sdxl_dpo_lora.png`
|
||||||
|
- `expected_sdxl_multi_loras.png`
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_sdxl_multi_loras.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_sdxl_multi_loras.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.7 MiB |
88
tests/foundationals/latent_diffusion/test_lora_manager.py
Normal file
88
tests/foundationals/latent_diffusion/test_lora_manager.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_tensors
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager() -> SDLoraManager:
|
||||||
|
target = StableDiffusion_1()
|
||||||
|
return SDLoraManager(target)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
|
||||||
|
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
||||||
|
|
||||||
|
if not weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
return load_tensors(weights_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=weights)
|
||||||
|
assert "pokemon-lora" in manager.names
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError) as exc:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=weights)
|
||||||
|
assert "already exists" in str(exc.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
|
||||||
|
assert "pokemon-lora" in manager.names
|
||||||
|
assert "pokemon-lora2" in manager.names
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
|
||||||
|
manager.remove_loras("pokemon-lora")
|
||||||
|
assert "pokemon-lora" not in manager.names
|
||||||
|
assert "pokemon-lora2" in manager.names
|
||||||
|
|
||||||
|
manager.remove_loras("pokemon-lora2")
|
||||||
|
assert "pokemon-lora2" not in manager.names
|
||||||
|
assert len(manager.names) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
|
||||||
|
manager.remove_all()
|
||||||
|
assert len(manager.names) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lora(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=weights)
|
||||||
|
assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_scale(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
|
||||||
|
assert manager.get_scale("pokemon-lora") == 0.4
|
||||||
|
|
||||||
|
|
||||||
|
def test_names(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
assert manager.names == []
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora", tensors=weights)
|
||||||
|
assert manager.names == ["pokemon-lora"]
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora2", tensors=weights)
|
||||||
|
assert manager.names == ["pokemon-lora", "pokemon-lora2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scales(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
assert manager.scales == {}
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
|
||||||
|
assert manager.scales == {"pokemon-lora": 0.4}
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora2", tensors=weights, scale=0.5)
|
||||||
|
assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}
|
Loading…
Reference in a new issue