make LoRA generic

Pierre Chapuis 2024-02-06 10:21:13 +01:00
parent 471ef91d1c
commit 37425fb609
3 changed files with 31 additions and 32 deletions

View file

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any, Generic, TypeVar, cast
 
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -7,8 +8,10 @@ from torch.nn.init import normal_, zeros_
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
 
+T = TypeVar("T", bound=fl.WeightedModule)
+
 
-class Lora(fl.Chain, ABC):
+class Lora(Generic[T], fl.Chain, ABC):
     """Low-Rank Adaptation (LoRA) layer.
 
     This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
@@ -55,9 +58,7 @@ class Lora(fl.Chain, ABC):
         zeros_(tensor=self.up.weight)
 
     @abstractmethod
-    def lora_layers(
-        self, device: Device | str | None = None, dtype: DType | None = None
-    ) -> tuple[fl.WeightedModule, fl.WeightedModule]:
+    def lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:
         """Create the down and up layers of the LoRA.
 
         Args:
@@ -67,18 +68,18 @@ class Lora(fl.Chain, ABC):
         ...
 
     @property
-    def down(self) -> fl.WeightedModule:
+    def down(self) -> T:
         """The down layer."""
         down_layer = self[0]
         assert isinstance(down_layer, fl.WeightedModule)
-        return down_layer
+        return cast(T, down_layer)
 
     @property
-    def up(self) -> fl.WeightedModule:
+    def up(self) -> T:
         """The up layer."""
         up_layer = self[1]
         assert isinstance(up_layer, fl.WeightedModule)
-        return up_layer
+        return cast(T, up_layer)
 
     @property
     def rank(self) -> int:
@@ -102,7 +103,7 @@ class Lora(fl.Chain, ABC):
         /,
         down: Tensor,
        up: Tensor,
-    ) -> "Lora":
+    ) -> "Lora[Any]":
         match (up.ndim, down.ndim):
             case (2, 2):
                 return LinearLora.from_weights(name, up=up, down=down)
@@ -112,14 +113,14 @@
                 raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
 
     @classmethod
-    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
+    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora[Any]"]:
         """
         Create a dictionary of LoRA layers from a state dict.
 
         Expects the state dict to be a succession of down and up weights.
         """
         state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
-        loras: dict[str, Lora] = {}
+        loras: dict[str, Lora[Any]] = {}
         for down_key, down_tensor, up_tensor in zip(
             list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
         ):
@@ -168,7 +169,7 @@ class Lora(fl.Chain, ABC):
         self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
 
 
-class LinearLora(Lora):
+class LinearLora(Lora[fl.Linear]):
     """Low-Rank Adaptation (LoRA) layer for linear layers.
 
     This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
@@ -254,7 +255,7 @@ class LinearLora(Lora):
         return False
 
 
-class Conv2dLora(Lora):
+class Conv2dLora(Lora[fl.Conv2d]):
     """Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
 
     This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
@@ -374,7 +375,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
     This adapter simply sums the target layer with the given LoRA layers.
     """
 
-    def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
+    def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
         """Initialize the adapter.
 
         Args:
@@ -387,24 +388,24 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
     @property
     def names(self) -> list[str]:
         """The names of the LoRA layers."""
-        return [lora.name for lora in self.layers(Lora)]
+        return [lora.name for lora in self.layers(Lora[Any])]
 
     @property
-    def loras(self) -> dict[str, Lora]:
+    def loras(self) -> dict[str, Lora[Any]]:
         """The LoRA layers indexed by name."""
-        return {lora.name: lora for lora in self.layers(Lora)}
+        return {lora.name: lora for lora in self.layers(Lora[Any])}
 
     @property
     def scales(self) -> dict[str, float]:
         """The scales of the LoRA layers indexed by names."""
-        return {lora.name: lora.scale for lora in self.layers(Lora)}
+        return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
 
     @scales.setter
     def scale(self, values: dict[str, float]) -> None:
         for name, value in values.items():
             self.loras[name].scale = value
 
-    def add_lora(self, lora: Lora, /) -> None:
+    def add_lora(self, lora: Lora[Any], /) -> None:
         """Add a LoRA layer to the adapter.
 
         Raises:
@@ -416,7 +417,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
         assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
         self.append(lora)
 
-    def remove_lora(self, name: str, /) -> Lora | None:
+    def remove_lora(self, name: str, /) -> Lora[Any] | None:
         """Remove a LoRA layer from the adapter.
 
         Note:

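A minimal usage sketch of the generic API above (not part of the commit). The `refiners.fluxion.adapters.lora` import path and the tensor shapes are assumptions; the dispatch of 2-D weights to LinearLora comes from the `match (up.ndim, down.ndim)` hunk.

import torch

from refiners.fluxion.adapters.lora import LinearLora, Lora

# 2-D down/up weights dispatch to LinearLora; the classmethod is now annotated
# as returning "Lora[Any]". Shapes are hypothetical: rank 16, 320 input
# features, 128 output features, in PyTorch's (out, in) Linear weight layout.
lora = Lora.from_weights("example", down=torch.randn(16, 320), up=torch.randn(128, 16))
assert isinstance(lora, LinearLora)
assert lora.rank == 16  # rank read from the down weight's first dimension (assumption)

# LinearLora subclasses Lora[fl.Linear], so `down` and `up` are typed as
# fl.Linear and attribute access needs no isinstance() narrowing.
assert lora.down.out_features == lora.up.in_features == lora.rank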
View file

@@ -1,3 +1,4 @@
+from typing import Any
 from warnings import warn
 
 from torch import Tensor
@@ -106,7 +107,7 @@ class SDLoraManager:
         for name, lora_tensors in tensors.items():
             self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
 
-    def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the text encoder.
 
         Args:
@@ -116,7 +117,7 @@ class SDLoraManager:
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
         SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
 
-    def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the U-Net.
 
         Args:
@@ -147,7 +148,7 @@ class SDLoraManager:
         for lora_adapter in self.lora_adapters:
             lora_adapter.eject()
 
-    def get_loras_by_name(self, name: str, /) -> list[Lora]:
+    def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:
         """Get the LoRA layers with the given name.
 
         Args:
@@ -190,9 +191,9 @@ class SDLoraManager:
             lora.scale = scale
 
     @property
-    def loras(self) -> list[Lora]:
+    def loras(self) -> list[Lora[Any]]:
         """List of all the LoRA layers managed by the SDLoraManager."""
-        return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
+        return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
 
     @property
     def names(self) -> list[str]:
@@ -239,12 +240,12 @@ class SDLoraManager:
 
     @staticmethod
     def auto_attach(
-        loras: dict[str, Lora],
+        loras: dict[str, Lora[Any]],
         target: fl.Chain,
         /,
         exclude: list[str] | None = None,
     ) -> None:
-        failed_loras: dict[str, Lora] = {}
+        failed_loras: dict[str, Lora[Any]] = {}
         for key, lora in loras.items():
             if attach := lora.auto_attach(target, exclude=exclude):
                 adapter, parent = attach

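Since `Lora` is generic, the manager-level annotations above use `Lora[Any]` wherever linear and convolutional LoRAs are mixed in one collection. A small sketch of that pattern: the helper below is hypothetical, the `LinearLora` arguments are assumptions, and the `Conv2dLora` arguments copy the test fixture in the next file.

from typing import Any

from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora

# Hypothetical helper: `Lora[Any]` annotates a heterogeneous collection of LoRAs.
def set_scales(loras: dict[str, Lora[Any]], scale: float) -> None:
    for lora in loras.values():
        lora.scale = scale  # `scale` has a setter, as used by SDLoraManager above

set_scales(
    {
        "linear": LinearLora("linear", in_features=320, out_features=128, rank=16),
        "conv": Conv2dLora("conv", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4),
    },
    scale=0.5,
)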
View file

@@ -11,11 +11,11 @@ def lora() -> LinearLora:
 
 
 @pytest.fixture
-def conv_lora() -> Lora:
+def conv_lora() -> Conv2dLora:
     return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
 
 
-def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
+def test_properties(lora: LinearLora, conv_lora: Conv2dLora) -> None:
     assert lora.name == "test"
     assert lora.rank == lora.down.out_features == lora.up.in_features == 16
     assert lora.scale == 1.0
@@ -27,7 +27,6 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
     assert conv_lora.scale == 1.0
     assert conv_lora.in_channels == conv_lora.down.in_channels == 16
     assert conv_lora.out_channels == conv_lora.up.out_channels == 8
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
     # padding is set so the spatial dimensions are preserved
     assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@@ -40,12 +39,10 @@ def test_scale_setter(lora: LinearLora) -> None:
 
 
 def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
-    assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
     new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
     x = torch.randn(1, 320)
     assert torch.allclose(lora(x), new_lora(x))
 
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
     x = torch.randn(1, 16, 64, 64)
     assert torch.allclose(conv_lora(x), new_conv_lora(x))
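The `isinstance` assertions removed above were only narrowing `down` and `up` for the type checker; with `LinearLora(Lora[fl.Linear])` and `Conv2dLora(Lora[fl.Conv2d])` that narrowing follows from the class declarations. A short sketch mirroring the fixtures (the `LinearLora` arguments are assumptions, the `Conv2dLora` ones copy the fixture above):

from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora

lora = LinearLora("test", in_features=320, out_features=128, rank=16)  # arguments assumed
conv_lora = Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)

# `lora.down` / `lora.up` are typed as fl.Linear and `conv_lora.down` /
# `conv_lora.up` as fl.Conv2d, so these accesses type-check without asserts.
assert lora.down.out_features == lora.up.in_features == 16
assert (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)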