(doc/fluxion/lora) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 22:32:24 +00:00 committed by Laureηt
parent 9fb9df5f91
commit 4054421854
2 changed files with 125 additions and 4 deletions

View file

@ -1,3 +1,10 @@
from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
__all__ = ["Adapter"] __all__ = [
"Adapter",
"Lora",
"LinearLora",
"Conv2dLora",
"LoraAdapter",
]

View file

@ -9,6 +9,21 @@ from refiners.fluxion.adapters.adapter import Adapter
class Lora(fl.Chain, ABC): class Lora(fl.Chain, ABC):
"""Low-rank approximation (LoRA) layer.
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
- `down`: initialized with a random normal distribution
- `up`: initialized with zeros
Note:
This layer is not meant to be used directly.
Instead, use one of its subclasses:
- [`LinearLora`][refiners.fluxion.adapters.lora.LinearLora]
- [`Conv2dLora`][refiners.fluxion.adapters.lora.Conv2dLora]
"""
def __init__( def __init__(
self, self,
name: str, name: str,
@ -18,11 +33,23 @@ class Lora(fl.Chain, ABC):
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
"""Initialize the LoRA layer.
Args:
name: The name of the LoRA.
rank: The rank of the LoRA.
scale: The scale of the LoRA.
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
self.name = name self.name = name
self._rank = rank self._rank = rank
self._scale = scale self._scale = scale
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale)) super().__init__(
*self.lora_layers(device=device, dtype=dtype),
fl.Multiply(scale),
)
normal_(tensor=self.down.weight, std=1 / self.rank) normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight) zeros_(tensor=self.up.weight)
@ -31,26 +58,36 @@ class Lora(fl.Chain, ABC):
def lora_layers( def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]: ) -> tuple[fl.WeightedModule, fl.WeightedModule]:
"""Create the down and up layers of the LoRA.
Args:
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
... ...
@property @property
def down(self) -> fl.WeightedModule: def down(self) -> fl.WeightedModule:
"""The down layer."""
down_layer = self[0] down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule) assert isinstance(down_layer, fl.WeightedModule)
return down_layer return down_layer
@property @property
def up(self) -> fl.WeightedModule: def up(self) -> fl.WeightedModule:
"""The up layer."""
up_layer = self[1] up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule) assert isinstance(up_layer, fl.WeightedModule)
return up_layer return up_layer
@property @property
def rank(self) -> int: def rank(self) -> int:
"""The rank of the low-rank approximation."""
return self._rank return self._rank
@property @property
def scale(self) -> float: def scale(self) -> float:
"""The scale of the low-rank approximation."""
return self._scale return self._scale
@scale.setter @scale.setter
@ -119,6 +156,12 @@ class Lora(fl.Chain, ABC):
return LoraAdapter(layer, self), parent return LoraAdapter(layer, self), parent
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
"""Load the weights of the LoRA.
Args:
down_weight: The down weight.
up_weight: The up weight.
"""
assert down_weight.shape == self.down.weight.shape assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape assert up_weight.shape == self.up.weight.shape
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype)) self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
@ -126,6 +169,11 @@ class Lora(fl.Chain, ABC):
class LinearLora(Lora): class LinearLora(Lora):
"""Low-rank approximation (LoRA) layer for linear layers.
This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
"""
def __init__( def __init__(
self, self,
name: str, name: str,
@ -137,10 +185,27 @@ class LinearLora(Lora):
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
"""Initialize the LoRA layer.
Args:
name: The name of the LoRA.
in_features: The number of input features.
out_features: The number of output features.
rank: The rank of the LoRA.
scale: The scale of the LoRA.
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype) super().__init__(
name,
rank=rank,
scale=scale,
device=device,
dtype=dtype,
)
@classmethod @classmethod
def from_weights( def from_weights(
@ -190,6 +255,11 @@ class LinearLora(Lora):
class Conv2dLora(Lora): class Conv2dLora(Lora):
"""Low-rank approximation (LoRA) layer for 2D convolutional layers.
This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
"""
def __init__( def __init__(
self, self,
name: str, name: str,
@ -204,13 +274,33 @@ class Conv2dLora(Lora):
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
"""Initialize the LoRA layer.
Args:
name: The name of the LoRA.
in_channels: The number of input channels.
out_channels: The number of output channels.
rank: The rank of the LoRA.
scale: The scale of the LoRA.
kernel_size: The kernel size of the LoRA.
stride: The stride of the LoRA.
padding: The padding of the LoRA.
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype) super().__init__(
name,
rank=rank,
scale=scale,
device=device,
dtype=dtype,
)
@classmethod @classmethod
def from_weights( def from_weights(
@ -279,20 +369,34 @@ class Conv2dLora(Lora):
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]): class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
"""Adapter for LoRA layers.
This adapter simply sums the target layer with the given LoRA layers.
"""
def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None: def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
"""Initialize the adapter.
Args:
target: The target layer.
loras: The LoRA layers.
"""
with self.setup_adapter(target): with self.setup_adapter(target):
super().__init__(target, *loras) super().__init__(target, *loras)
@property @property
def names(self) -> list[str]: def names(self) -> list[str]:
"""The names of the LoRA layers."""
return [lora.name for lora in self.layers(Lora)] return [lora.name for lora in self.layers(Lora)]
@property @property
def loras(self) -> dict[str, Lora]: def loras(self) -> dict[str, Lora]:
"""The LoRA layers."""
return {lora.name: lora for lora in self.layers(Lora)} return {lora.name: lora for lora in self.layers(Lora)}
@property @property
def scales(self) -> dict[str, float]: def scales(self) -> dict[str, float]:
"""The scales of the LoRA layers."""
return {lora.name: lora.scale for lora in self.layers(Lora)} return {lora.name: lora.scale for lora in self.layers(Lora)}
@scales.setter @scales.setter
@ -301,10 +405,20 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
self.loras[name].scale = value self.loras[name].scale = value
def add_lora(self, lora: Lora, /) -> None: def add_lora(self, lora: Lora, /) -> None:
"""Add a LoRA layer to the adapter.
Args:
lora: The LoRA layer to add.
"""
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists" assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
self.append(lora) self.append(lora)
def remove_lora(self, name: str, /) -> Lora | None: def remove_lora(self, name: str, /) -> Lora | None:
"""Remove a LoRA layer from the adapter.
Args:
name: The name of the LoRA layer to remove.
"""
if name in self.names: if name in self.names:
lora = self.loras[name] lora = self.loras[name]
self.remove(lora) self.remove(lora)