make LoRA's weight initialization overridable

Laurent 2024-03-13 14:19:31 +00:00 committed by Laureηt
parent c1b3a52141
commit b8fae60d38


@@ -14,10 +14,9 @@ T = TypeVar("T", bound=fl.WeightedModule)
class Lora(Generic[T], fl.Chain, ABC):
"""Low-Rank Adaptation (LoRA) layer.
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
- `down`: initialized with a random normal distribution
- `up`: initialized with zeros
This layer's purpose is to approximate a given layer by two smaller layers:
the [`down`][refiners.fluxion.adapters.lora.Lora.down] layer (aka A) and the [`up`][refiners.fluxion.adapters.lora.Lora.up] layer (aka B).
See [[ arXiv:2106.09685] LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) for more details.
Note:
This layer is not meant to be used directly.
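The docstring above summarizes the LoRA idea: approximate a layer's weight update with the product of a low-rank `down` (A) matrix and an `up` (B) matrix. The sketch below illustrates that idea in plain PyTorch; it does not use the refiners `Lora` class, and the dimensions and variable names are made up for illustration.

```python
import torch

# Toy dimensions for illustration.
rank, d_in, d_out, scale = 4, 64, 64, 1.0

# down (A): random normal init; up (B): zeros, so the LoRA update starts at zero.
down = torch.randn(rank, d_in) / rank
up = torch.zeros(d_out, rank)

x = torch.randn(2, d_in)
delta = (x @ down.T @ up.T) * scale  # low-rank update applied to the input
print(delta.shape)  # torch.Size([2, 64]); all zeros until `up` is trained
```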
@@ -53,7 +52,10 @@ class Lora(Generic[T], fl.Chain, ABC):
            *self.lora_layers(device=device, dtype=dtype),
            fl.Multiply(scale),
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters of up and down layers."""
        normal_(tensor=self.down.weight, std=1 / self.rank)
        zeros_(tensor=self.up.weight)
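Since the point of this commit is to make the weight initialization overridable, a subclass can now replace `reset_parameters` with its own scheme. A hedged sketch follows, assuming a concrete LoRA variant such as `LinearLora` is available to inherit from; the class name `KaimingLora` and the choice of init are hypothetical and not part of this commit.

```python
from torch.nn.init import kaiming_uniform_, zeros_

from refiners.fluxion.adapters.lora import LinearLora  # assumed concrete Lora variant


class KaimingLora(LinearLora):
    """Hypothetical LoRA variant with a custom `down` initialization."""

    def reset_parameters(self) -> None:
        # Kaiming-uniform for `down` instead of the default scaled normal;
        # `up` stays zero so the adapter still starts as a no-op.
        kaiming_uniform_(tensor=self.down.weight, a=5**0.5)
        zeros_(tensor=self.up.weight)
```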