From b8fae60d38c3eed119d84b847aff52778df9a7fe Mon Sep 17 00:00:00 2001
From: Laurent
Date: Wed, 13 Mar 2024 14:19:31 +0000
Subject: [PATCH] make LoRA's weight initialization overridable

---
 src/refiners/fluxion/adapters/lora.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py
index 29a4bb9..e493429 100644
--- a/src/refiners/fluxion/adapters/lora.py
+++ b/src/refiners/fluxion/adapters/lora.py
@@ -14,10 +14,9 @@ T = TypeVar("T", bound=fl.WeightedModule)
 class Lora(Generic[T], fl.Chain, ABC):
     """Low-Rank Adaptation (LoRA) layer.
 
-    This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
-
-    - `down`: initialized with a random normal distribution
-    - `up`: initialized with zeros
+    This layer's purpose is to approximate a given layer by two smaller layers:
+    the [`down`][refiners.fluxion.adapters.lora.Lora.down] layer (aka A) and the [`up`][refiners.fluxion.adapters.lora.Lora.up] layer (aka B).
+    See [[ arXiv:2106.09685] LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) for more details.
 
     Note:
         This layer is not meant to be used directly.
@@ -53,7 +52,10 @@ class Lora(Generic[T], fl.Chain, ABC):
             *self.lora_layers(device=device, dtype=dtype),
             fl.Multiply(scale),
         )
+        self.reset_parameters()
 
+    def reset_parameters(self) -> None:
+        """Reset the parameters of up and down layers."""
         normal_(tensor=self.down.weight, std=1 / self.rank)
         zeros_(tensor=self.up.weight)
 
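
With reset_parameters factored out as a method and called at the end of __init__, a subclass can swap in its own initialization scheme and have it applied automatically when the layer is built. A minimal sketch follows, assuming the concrete LinearLora subclass from the same module; the class name KaimingLinearLora and the choice of Kaiming-uniform initialization for the down (A) projection are illustrative, not part of this patch.

from torch.nn.init import kaiming_uniform_, zeros_

from refiners.fluxion.adapters.lora import LinearLora


class KaimingLinearLora(LinearLora):
    """Hypothetical LoRA variant that overrides the default weight initialization."""

    def reset_parameters(self) -> None:
        # Down (A) projection: Kaiming-uniform instead of the default normal with std=1/rank.
        kaiming_uniform_(tensor=self.down.weight, a=5**0.5)
        # Up (B) projection stays zero so the adapter starts as a no-op, as in the default scheme.
        zeros_(tensor=self.up.weight)

Keeping the up (B) weights at zero preserves the usual LoRA property that the adapted model initially behaves exactly like the base model; only the down (A) initialization is swapped here.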