make LoRA generic

This commit is contained in:
Pierre Chapuis 2024-02-06 10:21:13 +01:00
parent 471ef91d1c
commit 37425fb609
3 changed files with 31 additions and 32 deletions

View file

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar, cast
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
@ -7,8 +8,10 @@ from torch.nn.init import normal_, zeros_
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
T = TypeVar("T", bound=fl.WeightedModule)
class Lora(fl.Chain, ABC):
class Lora(Generic[T], fl.Chain, ABC):
"""Low-Rank Adaptation (LoRA) layer.
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
@ -55,9 +58,7 @@ class Lora(fl.Chain, ABC):
zeros_(tensor=self.up.weight)
@abstractmethod
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
def lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:
"""Create the down and up layers of the LoRA.
Args:
@ -67,18 +68,18 @@ class Lora(fl.Chain, ABC):
...
@property
def down(self) -> fl.WeightedModule:
def down(self) -> T:
"""The down layer."""
down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule)
return down_layer
return cast(T, down_layer)
@property
def up(self) -> fl.WeightedModule:
def up(self) -> T:
"""The up layer."""
up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule)
return up_layer
return cast(T, up_layer)
@property
def rank(self) -> int:
@ -102,7 +103,7 @@ class Lora(fl.Chain, ABC):
/,
down: Tensor,
up: Tensor,
) -> "Lora":
) -> "Lora[Any]":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(name, up=up, down=down)
@ -112,14 +113,14 @@ class Lora(fl.Chain, ABC):
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
@classmethod
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora[Any]"]:
"""
Create a dictionary of LoRA layers from a state dict.
Expects the state dict to be a succession of down and up weights.
"""
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
loras: dict[str, Lora] = {}
loras: dict[str, Lora[Any]] = {}
for down_key, down_tensor, up_tensor in zip(
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
@ -168,7 +169,7 @@ class Lora(fl.Chain, ABC):
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
class LinearLora(Lora):
class LinearLora(Lora[fl.Linear]):
"""Low-Rank Adaptation (LoRA) layer for linear layers.
This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
@ -254,7 +255,7 @@ class LinearLora(Lora):
return False
class Conv2dLora(Lora):
class Conv2dLora(Lora[fl.Conv2d]):
"""Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
@ -374,7 +375,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
This adapter simply sums the target layer with the given LoRA layers.
"""
def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
"""Initialize the adapter.
Args:
@ -387,24 +388,24 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
@property
def names(self) -> list[str]:
"""The names of the LoRA layers."""
return [lora.name for lora in self.layers(Lora)]
return [lora.name for lora in self.layers(Lora[Any])]
@property
def loras(self) -> dict[str, Lora]:
def loras(self) -> dict[str, Lora[Any]]:
"""The LoRA layers indexed by name."""
return {lora.name: lora for lora in self.layers(Lora)}
return {lora.name: lora for lora in self.layers(Lora[Any])}
@property
def scales(self) -> dict[str, float]:
"""The scales of the LoRA layers indexed by names."""
return {lora.name: lora.scale for lora in self.layers(Lora)}
return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
@scales.setter
def scale(self, values: dict[str, float]) -> None:
for name, value in values.items():
self.loras[name].scale = value
def add_lora(self, lora: Lora, /) -> None:
def add_lora(self, lora: Lora[Any], /) -> None:
"""Add a LoRA layer to the adapter.
Raises:
@ -416,7 +417,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
self.append(lora)
def remove_lora(self, name: str, /) -> Lora | None:
def remove_lora(self, name: str, /) -> Lora[Any] | None:
"""Remove a LoRA layer from the adapter.
Note:

View file

@ -1,3 +1,4 @@
from typing import Any
from warnings import warn
from torch import Tensor
@ -106,7 +107,7 @@ class SDLoraManager:
for name, lora_tensors in tensors.items():
self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the text encoder.
Args:
@ -116,7 +117,7 @@ class SDLoraManager:
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the U-Net.
Args:
@ -147,7 +148,7 @@ class SDLoraManager:
for lora_adapter in self.lora_adapters:
lora_adapter.eject()
def get_loras_by_name(self, name: str, /) -> list[Lora]:
def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:
"""Get the LoRA layers with the given name.
Args:
@ -190,9 +191,9 @@ class SDLoraManager:
lora.scale = scale
@property
def loras(self) -> list[Lora]:
def loras(self) -> list[Lora[Any]]:
"""List of all the LoRA layers managed by the SDLoraManager."""
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
@property
def names(self) -> list[str]:
@ -239,12 +240,12 @@ class SDLoraManager:
@staticmethod
def auto_attach(
loras: dict[str, Lora],
loras: dict[str, Lora[Any]],
target: fl.Chain,
/,
exclude: list[str] | None = None,
) -> None:
failed_loras: dict[str, Lora] = {}
failed_loras: dict[str, Lora[Any]] = {}
for key, lora in loras.items():
if attach := lora.auto_attach(target, exclude=exclude):
adapter, parent = attach

View file

@ -11,11 +11,11 @@ def lora() -> LinearLora:
@pytest.fixture
def conv_lora() -> Lora:
def conv_lora() -> Conv2dLora:
return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
def test_properties(lora: LinearLora, conv_lora: Conv2dLora) -> None:
assert lora.name == "test"
assert lora.rank == lora.down.out_features == lora.up.in_features == 16
assert lora.scale == 1.0
@ -27,7 +27,6 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
assert conv_lora.scale == 1.0
assert conv_lora.in_channels == conv_lora.down.in_channels == 16
assert conv_lora.out_channels == conv_lora.up.out_channels == 8
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
# padding is set so the spatial dimensions are preserved
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@ -40,12 +39,10 @@ def test_scale_setter(lora: LinearLora) -> None:
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
x = torch.randn(1, 320)
assert torch.allclose(lora(x), new_lora(x))
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
x = torch.randn(1, 16, 64, 64)
assert torch.allclose(conv_lora(x), new_conv_lora(x))