make LoRA's weight initialization overridable

This commit is contained in:
Laurent 2024-03-13 14:19:31 +00:00 committed by Laureηt
parent c1b3a52141
commit b8fae60d38

View file

@ -14,10 +14,9 @@ T = TypeVar("T", bound=fl.WeightedModule)
class Lora(Generic[T], fl.Chain, ABC): class Lora(Generic[T], fl.Chain, ABC):
"""Low-Rank Adaptation (LoRA) layer. """Low-Rank Adaptation (LoRA) layer.
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]: This layer's purpose is to approximate a given layer by two smaller layers:
the [`down`][refiners.fluxion.adapters.lora.Lora.down] layer (aka A) and the [`up`][refiners.fluxion.adapters.lora.Lora.up] layer (aka B).
- `down`: initialized with a random normal distribution See [[ arXiv:2106.09685] LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) for more details.
- `up`: initialized with zeros
Note: Note:
This layer is not meant to be used directly. This layer is not meant to be used directly.
@ -53,7 +52,10 @@ class Lora(Generic[T], fl.Chain, ABC):
*self.lora_layers(device=device, dtype=dtype), *self.lora_layers(device=device, dtype=dtype),
fl.Multiply(scale), fl.Multiply(scale),
) )
self.reset_parameters()
def reset_parameters(self) -> None:
"""Reset the parameters of up and down layers."""
normal_(tensor=self.down.weight, std=1 / self.rank) normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight) zeros_(tensor=self.up.weight)