Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-22 06:08:46 +00:00
(fluxion/layers/activations) replace ApproximateGeLU by GeLUApproximation
parent 2bdb42e88d
commit 8d190e4256
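In short: the standalone `ApproximateGeLU` module is removed, and `GeLU` itself gains an `approximation` parameter of type `GeLUApproximation` (with `NONE`, `TANH` and `SIGMOID` members), keeping the old sigmoid formula as one of its modes. The enum definition itself is not part of this diff; a plausible sketch consistent with the usage in the hunks below:

```python
# Illustrative only: member names come from the usage in the hunks below,
# the exact values are assumptions (the enum is not shown in this commit).
from enum import Enum

class GeLUApproximation(Enum):
    NONE = "none"
    TANH = "tanh"
    SIGMOID = "sigmoid"
```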
@@ -1,8 +1,8 @@
 from refiners.fluxion.layers.activations import (
     GLU,
     Activation,
-    ApproximateGeLU,
     GeLU,
+    GeLUApproximation,
     ReLU,
     Sigmoid,
     SiLU,
@@ -64,10 +64,10 @@ __all__ = [
     "InstanceNorm2d",
     "Activation",
     "GeLU",
+    "GeLUApproximation",
     "GLU",
     "SiLU",
     "ReLU",
-    "ApproximateGeLU",
     "Sigmoid",
     "Attention",
     "ScaledDotProductAttention",
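The two hunks above only touch the re-export surface: `ApproximateGeLU` disappears from `refiners.fluxion.layers.activations` and from `__all__`, and `GeLUApproximation` takes its place. A minimal sketch of what downstream imports look like after this change, assuming the names are exported exactly as the modified lists show:

```python
# Sketch only: relies on the post-commit export list shown above.
from refiners.fluxion.layers.activations import GeLU, GeLUApproximation

gelu_exact = GeLU()  # defaults to GeLUApproximation.NONE
gelu_tanh = GeLU(approximation=GeLUApproximation.TANH)
```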
@@ -97,23 +97,20 @@ class GeLU(Activation):
     ```
     """

-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x: Tensor) -> Tensor:
-        return gelu(x)  # type: ignore
-
-
-class ApproximateGeLU(Activation):
-    """
-    The approximate form of Gaussian Error Linear Unit (GELU)
-    For more details, see section 2: https://arxiv.org/abs/1606.08415
-    """
-
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        approximation: GeLUApproximation = GeLUApproximation.NONE,
+    ) -> None:
         super().__init__()
+        self.approximation = approximation

     def forward(self, x: Tensor) -> Tensor:
-        return x * sigmoid(1.702 * x)
+        match self.approximation:
+            case GeLUApproximation.NONE:
+                return gelu(x, approximate="none")
+            case GeLUApproximation.TANH:
+                return gelu(x, approximate="tanh")
+            case GeLUApproximation.SIGMOID:
+                return x * sigmoid(1.702 * x)

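The rewritten `forward` dispatches on the enum: `NONE` and `TANH` map onto the underlying `gelu(x, approximate=...)` call (PyTorch's functional GELU), while `SIGMOID` keeps the old `ApproximateGeLU` formula `x * sigmoid(1.702 * x)`, the "quick GELU" variant used for CLIP in the next hunk. A small sanity check of the three branches in plain PyTorch, independent of the refiners classes:

```python
# Sketch using only torch.nn.functional; mirrors the three branches in the diff.
import torch
from torch.nn.functional import gelu

x = torch.linspace(-3.0, 3.0, steps=7)

exact = gelu(x, approximate="none")             # GeLUApproximation.NONE
tanh_variant = gelu(x, approximate="tanh")      # GeLUApproximation.TANH
sigmoid_variant = x * torch.sigmoid(1.702 * x)  # GeLUApproximation.SIGMOID

# Both approximations stay close to the exact GELU on this range.
print((exact - tanh_variant).abs().max())
print((exact - sigmoid_variant).abs().max())
```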
@@ -146,7 +146,10 @@ class CLIPTextEncoder(fl.Chain):
         )
         if use_quick_gelu:
             for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
-                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
+                parent.replace(
+                    old_module=gelu,
+                    new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID),
+                )


 class CLIPTextEncoderL(CLIPTextEncoder):
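In the text encoder, `use_quick_gelu` now swaps each `fl.GeLU` for another `fl.GeLU` configured with the sigmoid approximation, instead of a separate `fl.ApproximateGeLU` module. For downstream code that instantiated the removed class directly, the migration would look roughly like this (a sketch; assumes `refiners.fluxion.layers` is the `fl` namespace used in the hunk above):

```python
import refiners.fluxion.layers as fl

# Before this commit (class removed):
#   activation = fl.ApproximateGeLU()
# After this commit, the equivalent sigmoid-based "quick GELU" is:
activation = fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID)
```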