diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py index 00ab130..610f5d4 100644 --- a/src/refiners/fluxion/layers/__init__.py +++ b/src/refiners/fluxion/layers/__init__.py @@ -1,8 +1,8 @@ from refiners.fluxion.layers.activations import ( GLU, Activation, - ApproximateGeLU, GeLU, + GeLUApproximation, ReLU, Sigmoid, SiLU, @@ -64,10 +64,10 @@ __all__ = [ "InstanceNorm2d", "Activation", "GeLU", + "GeLUApproximation", "GLU", "SiLU", "ReLU", - "ApproximateGeLU", "Sigmoid", "Attention", "ScaledDotProductAttention", diff --git a/src/refiners/fluxion/layers/activations.py b/src/refiners/fluxion/layers/activations.py index 2bb3ab5..786f644 100644 --- a/src/refiners/fluxion/layers/activations.py +++ b/src/refiners/fluxion/layers/activations.py @@ -97,24 +97,21 @@ class GeLU(Activation): ``` """ - def __init__(self) -> None: + def __init__( + self, + approximation: GeLUApproximation = GeLUApproximation.NONE, + ) -> None: super().__init__() + self.approximation = approximation def forward(self, x: Tensor) -> Tensor: - return gelu(x) # type: ignore - - -class ApproximateGeLU(Activation): - """ - The approximate form of Gaussian Error Linear Unit (GELU) - For more details, see section 2: https://arxiv.org/abs/1606.08415 - """ - - def __init__(self) -> None: - super().__init__() - - def forward(self, x: Tensor) -> Tensor: - return x * sigmoid(1.702 * x) + match self.approximation: + case GeLUApproximation.NONE: + return gelu(x, approximate="none") + case GeLUApproximation.TANH: + return gelu(x, approximate="tanh") + case GeLUApproximation.SIGMOID: + return x * sigmoid(1.702 * x) class Sigmoid(Activation): diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py index 85d10fa..c48a7c0 100644 --- a/src/refiners/foundationals/clip/text_encoder.py +++ b/src/refiners/foundationals/clip/text_encoder.py @@ -146,7 +146,10 @@ class CLIPTextEncoder(fl.Chain): ) if use_quick_gelu: for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)): - parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU()) + parent.replace( + old_module=gelu, + new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID), + ) class CLIPTextEncoderL(CLIPTextEncoder):