mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)

make LoRA generic

parent 471ef91d1c
commit 37425fb609
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any, Generic, TypeVar, cast
 
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -7,8 +8,10 @@ from torch.nn.init import normal_, zeros_
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
 
+T = TypeVar("T", bound=fl.WeightedModule)
+
 
-class Lora(fl.Chain, ABC):
+class Lora(Generic[T], fl.Chain, ABC):
     """Low-Rank Adaptation (LoRA) layer.
 
     This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
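Not part of the patch: a minimal, self-contained sketch of the typing pattern introduced above, using plain torch.nn modules in place of refiners' fl.Chain / fl.WeightedModule. A TypeVar bound to the base layer type lets each subclass pin down what its down and up children are, while cast() bridges the dynamically typed container access.

from typing import Generic, TypeVar, cast

import torch.nn as nn

T = TypeVar("T", bound=nn.Module)


class Pair(Generic[T], nn.Sequential):
    """Two-child container whose children are statically known to be of type T."""

    @property
    def down(self) -> T:
        return cast(T, self[0])  # Sequential.__getitem__ is typed as returning nn.Module

    @property
    def up(self) -> T:
        return cast(T, self[1])


class LinearPair(Pair[nn.Linear]):
    pass


pair = LinearPair(nn.Linear(16, 4), nn.Linear(4, 16))
print(pair.down.in_features, pair.up.out_features)  # 16 16; both are nn.Linear to the checker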
@@ -55,9 +58,7 @@ class Lora(fl.Chain, ABC):
         zeros_(tensor=self.up.weight)
 
     @abstractmethod
-    def lora_layers(
-        self, device: Device | str | None = None, dtype: DType | None = None
-    ) -> tuple[fl.WeightedModule, fl.WeightedModule]:
+    def lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:
         """Create the down and up layers of the LoRA.
 
         Args:
@@ -67,18 +68,18 @@ class Lora(fl.Chain, ABC):
         ...
 
     @property
-    def down(self) -> fl.WeightedModule:
+    def down(self) -> T:
         """The down layer."""
         down_layer = self[0]
         assert isinstance(down_layer, fl.WeightedModule)
-        return down_layer
+        return cast(T, down_layer)
 
     @property
-    def up(self) -> fl.WeightedModule:
+    def up(self) -> T:
         """The up layer."""
         up_layer = self[1]
         assert isinstance(up_layer, fl.WeightedModule)
-        return up_layer
+        return cast(T, up_layer)
 
     @property
     def rank(self) -> int:
@@ -102,7 +103,7 @@ class Lora(fl.Chain, ABC):
         /,
         down: Tensor,
         up: Tensor,
-    ) -> "Lora":
+    ) -> "Lora[Any]":
         match (up.ndim, down.ndim):
             case (2, 2):
                 return LinearLora.from_weights(name, up=up, down=down)
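Not part of the patch: how the ndim-based dispatch in from_weights behaves for 2-D weights. The import path and the weight shapes are assumptions consistent with the tests at the end of this patch.

import torch

from refiners.fluxion.adapters.lora import LinearLora, Lora  # assumed module path

down = torch.randn(16, 320)  # (rank, in_features) -> ndim == 2
up = torch.randn(320, 16)    # (out_features, rank) -> ndim == 2
lora = Lora.from_weights("test", down=down, up=up)  # annotated "Lora[Any]" after this change
assert isinstance(lora, LinearLora)  # the (2, 2) case dispatches to LinearLora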
@@ -112,14 +113,14 @@ class Lora(fl.Chain, ABC):
                 raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
 
     @classmethod
-    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
+    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora[Any]"]:
         """
         Create a dictionary of LoRA layers from a state dict.
 
         Expects the state dict to be a succession of down and up weights.
         """
         state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
-        loras: dict[str, Lora] = {}
+        loras: dict[str, Lora[Any]] = {}
         for down_key, down_tensor, up_tensor in zip(
             list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
         ):
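Not part of the patch: a standalone illustration of the slicing used by from_dict above. The state dict is expected to alternate down and up weights, so every other key/value is a down entry and the values shifted by one are the matching up entries.

import torch

state_dict = {
    "layer1.down.weight": torch.randn(4, 320),  # rank x in_features
    "layer1.up.weight": torch.randn(320, 4),    # out_features x rank
    "layer2.down.weight": torch.randn(4, 320),
    "layer2.up.weight": torch.randn(320, 4),
}

for down_key, down_tensor, up_tensor in zip(
    list(state_dict.keys())[::2],     # every other key is a "down" entry
    list(state_dict.values())[::2],   # the matching down tensors
    list(state_dict.values())[1::2],  # the values shifted by one are the up tensors
):
    print(down_key, tuple(down_tensor.shape), tuple(up_tensor.shape))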
@@ -168,7 +169,7 @@ class Lora(fl.Chain, ABC):
         self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
 
 
-class LinearLora(Lora):
+class LinearLora(Lora[fl.Linear]):
     """Low-Rank Adaptation (LoRA) layer for linear layers.
 
     This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
@@ -254,7 +255,7 @@ class LinearLora(Lora):
             return False
 
 
-class Conv2dLora(Lora):
+class Conv2dLora(Lora[fl.Conv2d]):
    """Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
 
    This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
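Not part of the patch: what the narrowed generic parameter buys at a call site. Constructor keywords are borrowed from the test fixtures at the end of this patch and may not be the complete signatures; the import path is an assumption.

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora  # assumed module path

lora = LinearLora("test", in_features=320, out_features=128, rank=16)
conv_lora = Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)

# Both properties used to be typed fl.WeightedModule; the checker now sees the
# concrete layer types, so no isinstance() narrowing is needed.
down_linear: fl.Linear = lora.down
up_conv: fl.Conv2d = conv_lora.up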
@@ -374,7 +375,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
    This adapter simply sums the target layer with the given LoRA layers.
    """
 
-    def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
+    def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
        """Initialize the adapter.
 
        Args:
@@ -387,24 +388,24 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
     @property
     def names(self) -> list[str]:
         """The names of the LoRA layers."""
-        return [lora.name for lora in self.layers(Lora)]
+        return [lora.name for lora in self.layers(Lora[Any])]
 
     @property
-    def loras(self) -> dict[str, Lora]:
+    def loras(self) -> dict[str, Lora[Any]]:
         """The LoRA layers indexed by name."""
-        return {lora.name: lora for lora in self.layers(Lora)}
+        return {lora.name: lora for lora in self.layers(Lora[Any])}
 
     @property
     def scales(self) -> dict[str, float]:
         """The scales of the LoRA layers indexed by names."""
-        return {lora.name: lora.scale for lora in self.layers(Lora)}
+        return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
 
     @scales.setter
     def scale(self, values: dict[str, float]) -> None:
         for name, value in values.items():
             self.loras[name].scale = value
 
-    def add_lora(self, lora: Lora, /) -> None:
+    def add_lora(self, lora: Lora[Any], /) -> None:
        """Add a LoRA layer to the adapter.
 
        Raises:
@@ -416,7 +417,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
         assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
         self.append(lora)
 
-    def remove_lora(self, name: str, /) -> Lora | None:
+    def remove_lora(self, name: str, /) -> Lora[Any] | None:
         """Remove a LoRA layer from the adapter.
 
         Note:
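Not part of the patch: a small sketch of the LoraAdapter API touched in the three hunks above, with made-up dimensions and an assumed import path. Note that the setter defined with @scales.setter is bound to the name `scale`, so that is the attribute to assign to.

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import LinearLora, LoraAdapter  # assumed module path

target = fl.Linear(in_features=320, out_features=320)
lora = LinearLora("style", in_features=320, out_features=320, rank=16)

adapter = LoraAdapter(target, lora)  # sums the target with the given LoRA layers
print(adapter.scales)                # e.g. {"style": 1.0}
adapter.scale = {"style": 0.5}       # setter defined via @scales.setter above
removed = adapter.remove_lora("style")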
@@ -1,3 +1,4 @@
+from typing import Any
 from warnings import warn
 
 from torch import Tensor
@@ -106,7 +107,7 @@ class SDLoraManager:
         for name, lora_tensors in tensors.items():
             self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
 
-    def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the text encoder.
 
         Args:
@@ -116,7 +117,7 @@ class SDLoraManager:
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
         SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
 
-    def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the U-Net.
 
         Args:
@@ -147,7 +148,7 @@ class SDLoraManager:
         for lora_adapter in self.lora_adapters:
             lora_adapter.eject()
 
-    def get_loras_by_name(self, name: str, /) -> list[Lora]:
+    def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:
         """Get the LoRA layers with the given name.
 
         Args:
@@ -190,9 +191,9 @@ class SDLoraManager:
             lora.scale = scale
 
     @property
-    def loras(self) -> list[Lora]:
+    def loras(self) -> list[Lora[Any]]:
         """List of all the LoRA layers managed by the SDLoraManager."""
-        return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
+        return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
 
     @property
     def names(self) -> list[str]:
@@ -239,12 +240,12 @@ class SDLoraManager:
 
     @staticmethod
     def auto_attach(
-        loras: dict[str, Lora],
+        loras: dict[str, Lora[Any]],
         target: fl.Chain,
         /,
         exclude: list[str] | None = None,
     ) -> None:
-        failed_loras: dict[str, Lora] = {}
+        failed_loras: dict[str, Lora[Any]] = {}
         for key, lora in loras.items():
             if attach := lora.auto_attach(target, exclude=exclude):
                 adapter, parent = attach
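Not part of the patch: a rough sketch of how the updated annotations compose on the manager side. The manager instance is typed as Any here to avoid guessing its constructor and import path; only methods visible in this patch are called, and weight loading is left out.

from typing import Any

from torch import Tensor

from refiners.fluxion.adapters.lora import Lora  # assumed module path


def attach_loras(manager: Any, name: str, state_dict: dict[str, Tensor]) -> None:
    # from_dict now returns dict[str, Lora[Any]], which is what the manager's
    # add_loras_to_* methods accept after this change.
    loras: dict[str, Lora[Any]] = Lora.from_dict(name, state_dict=state_dict)
    manager.add_loras_to_text_encoder(loras)  # filters keys containing "text" itself (shown above)
    manager.add_loras_to_unet(loras)          # assumed to filter its own keys the same way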
@@ -11,11 +11,11 @@ def lora() -> LinearLora:
 
 
 @pytest.fixture
-def conv_lora() -> Lora:
+def conv_lora() -> Conv2dLora:
     return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
 
 
-def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
+def test_properties(lora: LinearLora, conv_lora: Conv2dLora) -> None:
     assert lora.name == "test"
     assert lora.rank == lora.down.out_features == lora.up.in_features == 16
     assert lora.scale == 1.0
@@ -27,7 +27,6 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
     assert conv_lora.scale == 1.0
     assert conv_lora.in_channels == conv_lora.down.in_channels == 16
     assert conv_lora.out_channels == conv_lora.up.out_channels == 8
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
     # padding is set so the spatial dimensions are preserved
     assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@@ -40,12 +39,10 @@ def test_scale_setter(lora: LinearLora) -> None:
 
 
 def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
-    assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
     new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
     x = torch.randn(1, 320)
     assert torch.allclose(lora(x), new_lora(x))
 
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
     x = torch.randn(1, 16, 64, 64)
     assert torch.allclose(conv_lora(x), new_conv_lora(x))
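Not part of the patch: why the isinstance assertions above could be dropped. With Conv2dLora declared as Lora[fl.Conv2d], the fixture's annotation is enough for a type checker to accept attribute access on down and up without runtime narrowing (import path assumed).

from refiners.fluxion.adapters.lora import Conv2dLora  # assumed module path


def kernel_sizes(conv_lora: Conv2dLora) -> tuple[int, int]:
    # conv_lora.down and conv_lora.up are statically fl.Conv2d now, so .kernel_size
    # type-checks without an isinstance() assertion.
    return conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]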