mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
(fluxion/layers/activations) replace ApproximateGeLU
by GeLUApproximation
This commit is contained in:
parent
2bdb42e88d
commit
8d190e4256
|
@ -1,8 +1,8 @@
|
|||
from refiners.fluxion.layers.activations import (
|
||||
GLU,
|
||||
Activation,
|
||||
ApproximateGeLU,
|
||||
GeLU,
|
||||
GeLUApproximation,
|
||||
ReLU,
|
||||
Sigmoid,
|
||||
SiLU,
|
||||
|
@ -64,10 +64,10 @@ __all__ = [
|
|||
"InstanceNorm2d",
|
||||
"Activation",
|
||||
"GeLU",
|
||||
"GeLUApproximation",
|
||||
"GLU",
|
||||
"SiLU",
|
||||
"ReLU",
|
||||
"ApproximateGeLU",
|
||||
"Sigmoid",
|
||||
"Attention",
|
||||
"ScaledDotProductAttention",
|
||||
|
|
|
@ -97,24 +97,21 @@ class GeLU(Activation):
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
approximation: GeLUApproximation = GeLUApproximation.NONE,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.approximation = approximation
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return gelu(x) # type: ignore
|
||||
|
||||
|
||||
class ApproximateGeLU(Activation):
|
||||
"""
|
||||
The approximate form of Gaussian Error Linear Unit (GELU)
|
||||
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x * sigmoid(1.702 * x)
|
||||
match self.approximation:
|
||||
case GeLUApproximation.NONE:
|
||||
return gelu(x, approximate="none")
|
||||
case GeLUApproximation.TANH:
|
||||
return gelu(x, approximate="tanh")
|
||||
case GeLUApproximation.SIGMOID:
|
||||
return x * sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class Sigmoid(Activation):
|
||||
|
|
|
@ -146,7 +146,10 @@ class CLIPTextEncoder(fl.Chain):
|
|||
)
|
||||
if use_quick_gelu:
|
||||
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
|
||||
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
||||
parent.replace(
|
||||
old_module=gelu,
|
||||
new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID),
|
||||
)
|
||||
|
||||
|
||||
class CLIPTextEncoderL(CLIPTextEncoder):
|
||||
|
|
Loading…
Reference in a new issue