(fluxion/layers/activations) replace ApproximateGeLU by GeLUApproximation

This commit is contained in:
Laurent 2024-02-02 14:06:03 +00:00 committed by Cédric Deltheil
parent 2bdb42e88d
commit 8d190e4256
3 changed files with 18 additions and 18 deletions

View file

@@ -1,8 +1,8 @@
 from refiners.fluxion.layers.activations import (
     GLU,
     Activation,
-    ApproximateGeLU,
     GeLU,
+    GeLUApproximation,
     ReLU,
     Sigmoid,
     SiLU,
@@ -64,10 +64,10 @@ __all__ = [
     "InstanceNorm2d",
     "Activation",
     "GeLU",
+    "GeLUApproximation",
     "GLU",
     "SiLU",
     "ReLU",
-    "ApproximateGeLU",
     "Sigmoid",
     "Attention",
     "ScaledDotProductAttention",

View file

@@ -97,24 +97,21 @@ class GeLU(Activation):
     ```
     """

-    def __init__(self) -> None:
+    def __init__(
+        self,
+        approximation: GeLUApproximation = GeLUApproximation.NONE,
+    ) -> None:
         super().__init__()
+        self.approximation = approximation

     def forward(self, x: Tensor) -> Tensor:
-        return gelu(x)  # type: ignore
-
-
-class ApproximateGeLU(Activation):
-    """
-    The approximate form of Gaussian Error Linear Unit (GELU)
-
-    For more details, see section 2: https://arxiv.org/abs/1606.08415
-    """
-
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x * sigmoid(1.702 * x)
+        match self.approximation:
+            case GeLUApproximation.NONE:
+                return gelu(x, approximate="none")
+            case GeLUApproximation.TANH:
+                return gelu(x, approximate="tanh")
+            case GeLUApproximation.SIGMOID:
+                return x * sigmoid(1.702 * x)


 class Sigmoid(Activation):

View file

@@ -146,7 +146,10 @@ class CLIPTextEncoder(fl.Chain):
         )
         if use_quick_gelu:
             for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
-                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
+                parent.replace(
+                    old_module=gelu,
+                    new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID),
+                )


 class CLIPTextEncoderL(CLIPTextEncoder):