(fluxion/layers/activations) replace ApproximateGeLU by GeLUApproximation

This commit is contained in:
Laurent 2024-02-02 14:06:03 +00:00 committed by Cédric Deltheil
parent 2bdb42e88d
commit 8d190e4256
3 changed files with 18 additions and 18 deletions

View file

@ -1,8 +1,8 @@
from refiners.fluxion.layers.activations import (
GLU,
Activation,
ApproximateGeLU,
GeLU,
GeLUApproximation,
ReLU,
Sigmoid,
SiLU,
@ -64,10 +64,10 @@ __all__ = [
"InstanceNorm2d",
"Activation",
"GeLU",
"GeLUApproximation",
"GLU",
"SiLU",
"ReLU",
"ApproximateGeLU",
"Sigmoid",
"Attention",
"ScaledDotProductAttention",

View file

@ -97,23 +97,20 @@ class GeLU(Activation):
```
"""
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return gelu(x) # type: ignore
class ApproximateGeLU(Activation):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self) -> None:
def __init__(
self,
approximation: GeLUApproximation = GeLUApproximation.NONE,
) -> None:
super().__init__()
self.approximation = approximation
def forward(self, x: Tensor) -> Tensor:
match self.approximation:
case GeLUApproximation.NONE:
return gelu(x, approximate="none")
case GeLUApproximation.TANH:
return gelu(x, approximate="tanh")
case GeLUApproximation.SIGMOID:
return x * sigmoid(1.702 * x)

View file

@ -146,7 +146,10 @@ class CLIPTextEncoder(fl.Chain):
)
if use_quick_gelu:
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
parent.replace(
old_module=gelu,
new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID),
)
class CLIPTextEncoderL(CLIPTextEncoder):