Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-14 09:08:14 +00:00
(doc/fluxion/lora) add/convert docstrings to mkdocstrings format
This commit is contained in: commit 4054421854 (parent 9fb9df5f91)
@@ -1,3 +1,10 @@
 from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
 
-__all__ = ["Adapter"]
+__all__ = [
+    "Adapter",
+    "Lora",
+    "LinearLora",
+    "Conv2dLora",
+    "LoraAdapter",
+]
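Note: this hunk appears to edit the adapters package `__init__` (it re-imports `Adapter` and extends `__all__`), so after this change the LoRA classes can be imported from the package root rather than the `lora` submodule. A minimal sketch under that assumption:

# Assumes the hunk above lands in refiners/fluxion/adapters/__init__.py.
from refiners.fluxion.adapters import Conv2dLora, LinearLora, Lora, LoraAdapter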
@@ -9,6 +9,21 @@ from refiners.fluxion.adapters.adapter import Adapter
 
 
 class Lora(fl.Chain, ABC):
+    """Low-rank approximation (LoRA) layer.
+
+    This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
+
+    - `down`: initialized with a random normal distribution
+    - `up`: initialized with zeros
+
+    Note:
+        This layer is not meant to be used directly.
+        Instead, use one of its subclasses:
+
+        - [`LinearLora`][refiners.fluxion.adapters.lora.LinearLora]
+        - [`Conv2dLora`][refiners.fluxion.adapters.lora.Conv2dLora]
+    """
+
     def __init__(
         self,
         name: str,
@@ -18,11 +33,23 @@ class Lora(fl.Chain, ABC):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
+        """Initialize the LoRA layer.
+
+        Args:
+            name: The name of the LoRA.
+            rank: The rank of the LoRA.
+            scale: The scale of the LoRA.
+            device: The device of the LoRA weights.
+            dtype: The dtype of the LoRA weights.
+        """
         self.name = name
         self._rank = rank
         self._scale = scale
 
-        super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
+        super().__init__(
+            *self.lora_layers(device=device, dtype=dtype),
+            fl.Multiply(scale),
+        )
 
         normal_(tensor=self.down.weight, std=1 / self.rank)
         zeros_(tensor=self.up.weight)
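Usage note: the chain built above is `(down, up, Multiply(scale))`, i.e. it computes `up(down(x)) * scale`, with `down` drawn from a normal distribution of standard deviation `1 / rank` and `up` zeroed. A minimal plain-PyTorch sketch of that computation (hypothetical sizes, bias-free projections assumed) shows why a freshly initialized LoRA contributes nothing:

import torch

# Hypothetical standalone equivalent of the chain (down, up, Multiply(scale)).
in_features, out_features, rank, scale = 320, 320, 16, 1.0
down = torch.nn.Linear(in_features, rank, bias=False)
up = torch.nn.Linear(rank, out_features, bias=False)

# Mirror the initialization above: down has std 1 / rank, up is all zeros.
torch.nn.init.normal_(down.weight, std=1 / rank)
torch.nn.init.zeros_(up.weight)

x = torch.randn(2, in_features)
y = up(down(x)) * scale

# Because `up` starts at zero, the LoRA contribution is zero at initialization,
# so an adapted layer initially behaves exactly like its unadapted target.
assert torch.all(y == 0)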
@@ -31,26 +58,36 @@ class Lora(fl.Chain, ABC):
     def lora_layers(
         self, device: Device | str | None = None, dtype: DType | None = None
     ) -> tuple[fl.WeightedModule, fl.WeightedModule]:
+        """Create the down and up layers of the LoRA.
+
+        Args:
+            device: The device of the LoRA weights.
+            dtype: The dtype of the LoRA weights.
+        """
         ...
 
     @property
     def down(self) -> fl.WeightedModule:
+        """The down layer."""
         down_layer = self[0]
         assert isinstance(down_layer, fl.WeightedModule)
         return down_layer
 
     @property
     def up(self) -> fl.WeightedModule:
+        """The up layer."""
         up_layer = self[1]
         assert isinstance(up_layer, fl.WeightedModule)
         return up_layer
 
     @property
     def rank(self) -> int:
+        """The rank of the low-rank approximation."""
         return self._rank
 
     @property
     def scale(self) -> float:
+        """The scale of the low-rank approximation."""
         return self._scale
 
     @scale.setter
@@ -119,6 +156,12 @@ class Lora(fl.Chain, ABC):
         return LoraAdapter(layer, self), parent
 
     def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
+        """Load the weights of the LoRA.
+
+        Args:
+            down_weight: The down weight.
+            up_weight: The up weight.
+        """
         assert down_weight.shape == self.down.weight.shape
         assert up_weight.shape == self.up.weight.shape
         self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
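Usage sketch for `load_weights` (hypothetical tensor names; torch-style `(out_features, in_features)` weight layout assumed): the shape asserts mean the provided tensors must match the `down` and `up` weights exactly before they are moved onto the layer's device and dtype.

import torch
from refiners.fluxion.adapters.lora import LinearLora

in_features, out_features, rank = 768, 768, 4
lora = LinearLora("my_lora", in_features=in_features, out_features=out_features, rank=rank)

# Hypothetical pretrained tensors, e.g. read from a checkpoint file.
down_weight = torch.randn(rank, in_features)   # assumed shape of lora.down.weight
up_weight = torch.randn(out_features, rank)    # assumed shape of lora.up.weight

lora.load_weights(down_weight, up_weight)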
@@ -126,6 +169,11 @@ class Lora(fl.Chain, ABC):
 
 
 class LinearLora(Lora):
+    """Low-rank approximation (LoRA) layer for linear layers.
+
+    This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
+    """
+
     def __init__(
         self,
         name: str,
@@ -137,10 +185,27 @@ class LinearLora(Lora):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
+        """Initialize the LoRA layer.
+
+        Args:
+            name: The name of the LoRA.
+            in_features: The number of input features.
+            out_features: The number of output features.
+            rank: The rank of the LoRA.
+            scale: The scale of the LoRA.
+            device: The device of the LoRA weights.
+            dtype: The dtype of the LoRA weights.
+        """
         self.in_features = in_features
         self.out_features = out_features
 
-        super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
+        super().__init__(
+            name,
+            rank=rank,
+            scale=scale,
+            device=device,
+            dtype=dtype,
+        )
 
     @classmethod
     def from_weights(
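A hedged construction sketch for `LinearLora`, using the arguments documented above (hypothetical name and sizes; `device`/`dtype` left at their defaults):

import torch
from refiners.fluxion.adapters.lora import LinearLora

lora = LinearLora("my_lora", in_features=320, out_features=320, rank=16, scale=1.0)

x = torch.randn(1, 320)
y = lora(x)                      # down (320 -> 16), up (16 -> 320), then * scale

assert y.shape == (1, 320)       # maps in_features to out_features
print(lora.rank, lora.scale)     # 16, 1.0 (properties documented earlier in this diff)
# With `up` zero-initialized, y is expected to be all zeros at this point.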
@@ -190,6 +255,11 @@ class LinearLora(Lora):
 
 
 class Conv2dLora(Lora):
+    """Low-rank approximation (LoRA) layer for 2D convolutional layers.
+
+    This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
+    """
+
     def __init__(
         self,
         name: str,
@@ -204,13 +274,33 @@ class Conv2dLora(Lora):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
+        """Initialize the LoRA layer.
+
+        Args:
+            name: The name of the LoRA.
+            in_channels: The number of input channels.
+            out_channels: The number of output channels.
+            rank: The rank of the LoRA.
+            scale: The scale of the LoRA.
+            kernel_size: The kernel size of the LoRA.
+            stride: The stride of the LoRA.
+            padding: The padding of the LoRA.
+            device: The device of the LoRA weights.
+            dtype: The dtype of the LoRA weights.
+        """
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
 
-        super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
+        super().__init__(
+            name,
+            rank=rank,
+            scale=scale,
+            device=device,
+            dtype=dtype,
+        )
 
     @classmethod
     def from_weights(
@@ -279,20 +369,34 @@ class Conv2dLora(Lora):
 
 
 class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
+    """Adapter for LoRA layers.
+
+    This adapter simply sums the target layer with the given LoRA layers.
+    """
+
     def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
+        """Initialize the adapter.
+
+        Args:
+            target: The target layer.
+            loras: The LoRA layers.
+        """
         with self.setup_adapter(target):
             super().__init__(target, *loras)
 
     @property
     def names(self) -> list[str]:
+        """The names of the LoRA layers."""
         return [lora.name for lora in self.layers(Lora)]
 
     @property
     def loras(self) -> dict[str, Lora]:
+        """The LoRA layers."""
         return {lora.name: lora for lora in self.layers(Lora)}
 
     @property
     def scales(self) -> dict[str, float]:
+        """The scales of the LoRA layers."""
         return {lora.name: lora.scale for lora in self.layers(Lora)}
 
     @scales.setter
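A hedged sketch of the adapter, assuming `fl.Linear(in_features, out_features)` and that a `LoraAdapter`, being an `fl.Sum`, can be called like any other Fluxion chain:

import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import LinearLora, LoraAdapter

target = fl.Linear(320, 320)   # the layer being adapted
lora = LinearLora("my_lora", in_features=320, out_features=320, rank=16, scale=1.0)
adapter = LoraAdapter(target, lora)

x = torch.randn(1, 320)
# As a Sum, the adapter adds the LoRA output on top of the target output.
assert torch.allclose(adapter(x), target(x) + lora(x))

print(adapter.names)    # ["my_lora"]
print(adapter.scales)   # {"my_lora": 1.0}

Swapping the adapter in for the target inside a larger chain is handled by the inherited `Adapter` machinery, which this hunk does not touch.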
@@ -301,10 +405,20 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
         self.loras[name].scale = value
 
     def add_lora(self, lora: Lora, /) -> None:
+        """Add a LoRA layer to the adapter.
+
+        Args:
+            lora: The LoRA layer to add.
+        """
         assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
         self.append(lora)
 
     def remove_lora(self, name: str, /) -> Lora | None:
+        """Remove a LoRA layer from the adapter.
+
+        Args:
+            name: The name of the LoRA layer to remove.
+        """
         if name in self.names:
             lora = self.loras[name]
             self.remove(lora)
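Continuing the sketch above, LoRA layers can be attached and detached by name at runtime; names must be unique, as enforced by the assert in `add_lora`:

from refiners.fluxion.adapters.lora import LinearLora

# `adapter` is the LoraAdapter from the previous sketch.
other = LinearLora("other_lora", in_features=320, out_features=320, rank=8, scale=1.0)
adapter.add_lora(other)
print(adapter.names)                          # ["my_lora", "other_lora"]

removed = adapter.remove_lora("other_lora")   # Lora | None, per the signature above
print(adapter.names)                          # ["my_lora"]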