mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
make LoRA generic
This commit is contained in:
parent
471ef91d1c
commit
37425fb609
|
@ -1,4 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar, cast
|
||||
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
|
@ -7,8 +8,10 @@ from torch.nn.init import normal_, zeros_
|
|||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
|
||||
T = TypeVar("T", bound=fl.WeightedModule)
|
||||
|
||||
class Lora(fl.Chain, ABC):
|
||||
|
||||
class Lora(Generic[T], fl.Chain, ABC):
|
||||
"""Low-Rank Adaptation (LoRA) layer.
|
||||
|
||||
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
|
||||
|
@ -55,9 +58,7 @@ class Lora(fl.Chain, ABC):
|
|||
zeros_(tensor=self.up.weight)
|
||||
|
||||
@abstractmethod
|
||||
def lora_layers(
|
||||
self, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
|
||||
def lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:
|
||||
"""Create the down and up layers of the LoRA.
|
||||
|
||||
Args:
|
||||
|
@ -67,18 +68,18 @@ class Lora(fl.Chain, ABC):
|
|||
...
|
||||
|
||||
@property
|
||||
def down(self) -> fl.WeightedModule:
|
||||
def down(self) -> T:
|
||||
"""The down layer."""
|
||||
down_layer = self[0]
|
||||
assert isinstance(down_layer, fl.WeightedModule)
|
||||
return down_layer
|
||||
return cast(T, down_layer)
|
||||
|
||||
@property
|
||||
def up(self) -> fl.WeightedModule:
|
||||
def up(self) -> T:
|
||||
"""The up layer."""
|
||||
up_layer = self[1]
|
||||
assert isinstance(up_layer, fl.WeightedModule)
|
||||
return up_layer
|
||||
return cast(T, up_layer)
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
|
@ -102,7 +103,7 @@ class Lora(fl.Chain, ABC):
|
|||
/,
|
||||
down: Tensor,
|
||||
up: Tensor,
|
||||
) -> "Lora":
|
||||
) -> "Lora[Any]":
|
||||
match (up.ndim, down.ndim):
|
||||
case (2, 2):
|
||||
return LinearLora.from_weights(name, up=up, down=down)
|
||||
|
@ -112,14 +113,14 @@ class Lora(fl.Chain, ABC):
|
|||
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
|
||||
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora[Any]"]:
|
||||
"""
|
||||
Create a dictionary of LoRA layers from a state dict.
|
||||
|
||||
Expects the state dict to be a succession of down and up weights.
|
||||
"""
|
||||
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
|
||||
loras: dict[str, Lora] = {}
|
||||
loras: dict[str, Lora[Any]] = {}
|
||||
for down_key, down_tensor, up_tensor in zip(
|
||||
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
|
||||
):
|
||||
|
@ -168,7 +169,7 @@ class Lora(fl.Chain, ABC):
|
|||
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
|
||||
|
||||
|
||||
class LinearLora(Lora):
|
||||
class LinearLora(Lora[fl.Linear]):
|
||||
"""Low-Rank Adaptation (LoRA) layer for linear layers.
|
||||
|
||||
This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
|
||||
|
@ -254,7 +255,7 @@ class LinearLora(Lora):
|
|||
return False
|
||||
|
||||
|
||||
class Conv2dLora(Lora):
|
||||
class Conv2dLora(Lora[fl.Conv2d]):
|
||||
"""Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
|
||||
|
||||
This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
|
||||
|
@ -374,7 +375,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
|||
This adapter simply sums the target layer with the given LoRA layers.
|
||||
"""
|
||||
|
||||
def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
|
||||
def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
|
@ -387,24 +388,24 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
|||
@property
|
||||
def names(self) -> list[str]:
|
||||
"""The names of the LoRA layers."""
|
||||
return [lora.name for lora in self.layers(Lora)]
|
||||
return [lora.name for lora in self.layers(Lora[Any])]
|
||||
|
||||
@property
|
||||
def loras(self) -> dict[str, Lora]:
|
||||
def loras(self) -> dict[str, Lora[Any]]:
|
||||
"""The LoRA layers indexed by name."""
|
||||
return {lora.name: lora for lora in self.layers(Lora)}
|
||||
return {lora.name: lora for lora in self.layers(Lora[Any])}
|
||||
|
||||
@property
|
||||
def scales(self) -> dict[str, float]:
|
||||
"""The scales of the LoRA layers indexed by names."""
|
||||
return {lora.name: lora.scale for lora in self.layers(Lora)}
|
||||
return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
|
||||
|
||||
@scales.setter
|
||||
def scale(self, values: dict[str, float]) -> None:
|
||||
for name, value in values.items():
|
||||
self.loras[name].scale = value
|
||||
|
||||
def add_lora(self, lora: Lora, /) -> None:
|
||||
def add_lora(self, lora: Lora[Any], /) -> None:
|
||||
"""Add a LoRA layer to the adapter.
|
||||
|
||||
Raises:
|
||||
|
@ -416,7 +417,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
|||
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
|
||||
self.append(lora)
|
||||
|
||||
def remove_lora(self, name: str, /) -> Lora | None:
|
||||
def remove_lora(self, name: str, /) -> Lora[Any] | None:
|
||||
"""Remove a LoRA layer from the adapter.
|
||||
|
||||
Note:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Any
|
||||
from warnings import warn
|
||||
|
||||
from torch import Tensor
|
||||
|
@ -106,7 +107,7 @@ class SDLoraManager:
|
|||
for name, lora_tensors in tensors.items():
|
||||
self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
|
||||
|
||||
def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
|
||||
def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
|
||||
"""Add multiple LoRAs to the text encoder.
|
||||
|
||||
Args:
|
||||
|
@ -116,7 +117,7 @@ class SDLoraManager:
|
|||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
|
||||
|
||||
def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
|
||||
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
|
||||
"""Add multiple LoRAs to the U-Net.
|
||||
|
||||
Args:
|
||||
|
@ -147,7 +148,7 @@ class SDLoraManager:
|
|||
for lora_adapter in self.lora_adapters:
|
||||
lora_adapter.eject()
|
||||
|
||||
def get_loras_by_name(self, name: str, /) -> list[Lora]:
|
||||
def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:
|
||||
"""Get the LoRA layers with the given name.
|
||||
|
||||
Args:
|
||||
|
@ -190,9 +191,9 @@ class SDLoraManager:
|
|||
lora.scale = scale
|
||||
|
||||
@property
|
||||
def loras(self) -> list[Lora]:
|
||||
def loras(self) -> list[Lora[Any]]:
|
||||
"""List of all the LoRA layers managed by the SDLoraManager."""
|
||||
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
|
||||
return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
|
||||
|
||||
@property
|
||||
def names(self) -> list[str]:
|
||||
|
@ -239,12 +240,12 @@ class SDLoraManager:
|
|||
|
||||
@staticmethod
|
||||
def auto_attach(
|
||||
loras: dict[str, Lora],
|
||||
loras: dict[str, Lora[Any]],
|
||||
target: fl.Chain,
|
||||
/,
|
||||
exclude: list[str] | None = None,
|
||||
) -> None:
|
||||
failed_loras: dict[str, Lora] = {}
|
||||
failed_loras: dict[str, Lora[Any]] = {}
|
||||
for key, lora in loras.items():
|
||||
if attach := lora.auto_attach(target, exclude=exclude):
|
||||
adapter, parent = attach
|
||||
|
|
|
@ -11,11 +11,11 @@ def lora() -> LinearLora:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def conv_lora() -> Lora:
|
||||
def conv_lora() -> Conv2dLora:
|
||||
return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
|
||||
|
||||
|
||||
def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
|
||||
def test_properties(lora: LinearLora, conv_lora: Conv2dLora) -> None:
|
||||
assert lora.name == "test"
|
||||
assert lora.rank == lora.down.out_features == lora.up.in_features == 16
|
||||
assert lora.scale == 1.0
|
||||
|
@ -27,7 +27,6 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
|
|||
assert conv_lora.scale == 1.0
|
||||
assert conv_lora.in_channels == conv_lora.down.in_channels == 16
|
||||
assert conv_lora.out_channels == conv_lora.up.out_channels == 8
|
||||
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
|
||||
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
|
||||
# padding is set so the spatial dimensions are preserved
|
||||
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
|
||||
|
@ -40,12 +39,10 @@ def test_scale_setter(lora: LinearLora) -> None:
|
|||
|
||||
|
||||
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
|
||||
assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
|
||||
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
|
||||
x = torch.randn(1, 320)
|
||||
assert torch.allclose(lora(x), new_lora(x))
|
||||
|
||||
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
|
||||
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
|
||||
x = torch.randn(1, 16, 64, 64)
|
||||
assert torch.allclose(conv_lora(x), new_conv_lora(x))
|
||||
|
|
Loading…
Reference in a new issue