mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
(doc/fluxion/lora) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
9fb9df5f91
commit
4054421854
|
@ -1,3 +1,10 @@
|
|||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
|
||||
|
||||
__all__ = ["Adapter"]
|
||||
__all__ = [
|
||||
"Adapter",
|
||||
"Lora",
|
||||
"LinearLora",
|
||||
"Conv2dLora",
|
||||
"LoraAdapter",
|
||||
]
|
||||
|
|
|
@ -9,6 +9,21 @@ from refiners.fluxion.adapters.adapter import Adapter
|
|||
|
||||
|
||||
class Lora(fl.Chain, ABC):
|
||||
"""Low-rank approximation (LoRA) layer.
|
||||
|
||||
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
|
||||
|
||||
- `down`: initialized with a random normal distribution
|
||||
- `up`: initialized with zeros
|
||||
|
||||
Note:
|
||||
This layer is not meant to be used directly.
|
||||
Instead, use one of its subclasses:
|
||||
|
||||
- [`LinearLora`][refiners.fluxion.adapters.lora.LinearLora]
|
||||
- [`Conv2dLora`][refiners.fluxion.adapters.lora.Conv2dLora]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
|
@ -18,11 +33,23 @@ class Lora(fl.Chain, ABC):
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initialize the LoRA layer.
|
||||
|
||||
Args:
|
||||
name: The name of the LoRA.
|
||||
rank: The rank of the LoRA.
|
||||
scale: The scale of the LoRA.
|
||||
device: The device of the LoRA weights.
|
||||
dtype: The dtype of the LoRA weights.
|
||||
"""
|
||||
self.name = name
|
||||
self._rank = rank
|
||||
self._scale = scale
|
||||
|
||||
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
|
||||
super().__init__(
|
||||
*self.lora_layers(device=device, dtype=dtype),
|
||||
fl.Multiply(scale),
|
||||
)
|
||||
|
||||
normal_(tensor=self.down.weight, std=1 / self.rank)
|
||||
zeros_(tensor=self.up.weight)
|
||||
|
@ -31,26 +58,36 @@ class Lora(fl.Chain, ABC):
|
|||
def lora_layers(
|
||||
self, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
|
||||
"""Create the down and up layers of the LoRA.
|
||||
|
||||
Args:
|
||||
device: The device of the LoRA weights.
|
||||
dtype: The dtype of the LoRA weights.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def down(self) -> fl.WeightedModule:
|
||||
"""The down layer."""
|
||||
down_layer = self[0]
|
||||
assert isinstance(down_layer, fl.WeightedModule)
|
||||
return down_layer
|
||||
|
||||
@property
|
||||
def up(self) -> fl.WeightedModule:
|
||||
"""The up layer."""
|
||||
up_layer = self[1]
|
||||
assert isinstance(up_layer, fl.WeightedModule)
|
||||
return up_layer
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
"""The rank of the low-rank approximation."""
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""The scale of the low-rank approximation."""
|
||||
return self._scale
|
||||
|
||||
@scale.setter
|
||||
|
@ -119,6 +156,12 @@ class Lora(fl.Chain, ABC):
|
|||
return LoraAdapter(layer, self), parent
|
||||
|
||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||
"""Load the weights of the LoRA.
|
||||
|
||||
Args:
|
||||
down_weight: The down weight.
|
||||
up_weight: The up weight.
|
||||
"""
|
||||
assert down_weight.shape == self.down.weight.shape
|
||||
assert up_weight.shape == self.up.weight.shape
|
||||
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
|
||||
|
@ -126,6 +169,11 @@ class Lora(fl.Chain, ABC):
|
|||
|
||||
|
||||
class LinearLora(Lora):
|
||||
"""Low-rank approximation (LoRA) layer for linear layers.
|
||||
|
||||
This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
|
@ -137,10 +185,27 @@ class LinearLora(Lora):
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initialize the LoRA layer.
|
||||
|
||||
Args:
|
||||
name: The name of the LoRA.
|
||||
in_features: The number of input features.
|
||||
out_features: The number of output features.
|
||||
rank: The rank of the LoRA.
|
||||
scale: The scale of the LoRA.
|
||||
device: The device of the LoRA weights.
|
||||
dtype: The dtype of the LoRA weights.
|
||||
"""
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
|
||||
super().__init__(
|
||||
name,
|
||||
rank=rank,
|
||||
scale=scale,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_weights(
|
||||
|
@ -190,6 +255,11 @@ class LinearLora(Lora):
|
|||
|
||||
|
||||
class Conv2dLora(Lora):
|
||||
"""Low-rank approximation (LoRA) layer for 2D convolutional layers.
|
||||
|
||||
This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
|
@ -204,13 +274,33 @@ class Conv2dLora(Lora):
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initialize the LoRA layer.
|
||||
|
||||
Args:
|
||||
name: The name of the LoRA.
|
||||
in_channels: The number of input channels.
|
||||
out_channels: The number of output channels.
|
||||
rank: The rank of the LoRA.
|
||||
scale: The scale of the LoRA.
|
||||
kernel_size: The kernel size of the LoRA.
|
||||
stride: The stride of the LoRA.
|
||||
padding: The padding of the LoRA.
|
||||
device: The device of the LoRA weights.
|
||||
dtype: The dtype of the LoRA weights.
|
||||
"""
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
|
||||
super().__init__(
|
||||
name,
|
||||
rank=rank,
|
||||
scale=scale,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_weights(
|
||||
|
@ -279,20 +369,34 @@ class Conv2dLora(Lora):
|
|||
|
||||
|
||||
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
||||
"""Adapter for LoRA layers.
|
||||
|
||||
This adapter simply sums the target layer with the given LoRA layers.
|
||||
"""
|
||||
|
||||
def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
target: The target layer.
|
||||
loras: The LoRA layers.
|
||||
"""
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target, *loras)
|
||||
|
||||
@property
|
||||
def names(self) -> list[str]:
|
||||
"""The names of the LoRA layers."""
|
||||
return [lora.name for lora in self.layers(Lora)]
|
||||
|
||||
@property
|
||||
def loras(self) -> dict[str, Lora]:
|
||||
"""The LoRA layers."""
|
||||
return {lora.name: lora for lora in self.layers(Lora)}
|
||||
|
||||
@property
|
||||
def scales(self) -> dict[str, float]:
|
||||
"""The scales of the LoRA layers."""
|
||||
return {lora.name: lora.scale for lora in self.layers(Lora)}
|
||||
|
||||
@scales.setter
|
||||
|
@ -301,10 +405,20 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
|||
self.loras[name].scale = value
|
||||
|
||||
def add_lora(self, lora: Lora, /) -> None:
|
||||
"""Add a LoRA layer to the adapter.
|
||||
|
||||
Args:
|
||||
lora: The LoRA layer to add.
|
||||
"""
|
||||
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
|
||||
self.append(lora)
|
||||
|
||||
def remove_lora(self, name: str, /) -> Lora | None:
|
||||
"""Remove a LoRA layer from the adapter.
|
||||
|
||||
Args:
|
||||
name: The name of the LoRA layer to remove.
|
||||
"""
|
||||
if name in self.names:
|
||||
lora = self.loras[name]
|
||||
self.remove(lora)
|
||||
|
|
Loading…
Reference in a new issue