Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
(doc/fluxion/activations) add/convert docstrings to mkdocstrings format
parent 0fc3264fae
commit a7c048f5fb
@@ -1,34 +1,102 @@
-from torch import Tensor, sigmoid
+from abc import ABC
+from enum import Enum
+
+from torch import Tensor
 from torch.nn.functional import (
-    gelu,  # type: ignore
+    gelu,
+    relu,
+    sigmoid,
     silu,
 )
 
 from refiners.fluxion.layers.module import Module
 
 
-class Activation(Module):
+class Activation(Module, ABC):
+    """Base class for activation layers.
+
+    Activation layers are layers that apply a (non-linear) function to their input.
+
+    Receives:
+        x (Tensor):
+
+    Returns:
+        (Tensor):
+    """
+
     def __init__(self) -> None:
         super().__init__()
 
 
 class SiLU(Activation):
+    """Sigmoid Linear Unit activation function.
+
+    See [[arXiv:1702.03118] Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning](https://arxiv.org/abs/1702.03118) for more details.
+    """
+
     def __init__(self) -> None:
         super().__init__()
 
     def forward(self, x: Tensor) -> Tensor:
-        return silu(x)  # type: ignore
+        return silu(x)
 
 
 class ReLU(Activation):
+    """Rectified Linear Unit activation function.
+
+    See [Rectified Linear Units Improve Restricted Boltzmann Machines](https://www.cs.toronto.edu/%7Efritz/absps/reluICML.pdf)
+    and [Cognitron: A self-organizing multilayered neural network](https://link.springer.com/article/10.1007/BF00342633)
+
+    Example:
+        ```py
+        relu = fl.ReLU()
+
+        tensor = torch.tensor([[-1.0, 0.0, 1.0]])
+        output = relu(tensor)
+
+        expected_output = torch.tensor([[0.0, 0.0, 1.0]])
+        assert torch.allclose(output, expected_output)
+        ```
+    """
+
     def __init__(self) -> None:
         super().__init__()
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.relu()
+        return relu(x)
 
 
+class GeLUApproximation(Enum):
+    """Approximation methods for the Gaussian Error Linear Unit activation function.
+
+    Attributes:
+        NONE: No approximation, use the original formula.
+        TANH: Use the tanh approximation.
+        SIGMOID: Use the sigmoid approximation.
+    """
+
+    NONE = "none"
+    TANH = "tanh"
+    SIGMOID = "sigmoid"
+
+
 class GeLU(Activation):
+    """Gaussian Error Linear Unit activation function.
+
+    This activation can be quite expensive to compute, a few approximations are available,
+    see [`GeLUApproximation`][refiners.fluxion.layers.activations.GeLUApproximation].
+
+    See [[arXiv:1606.08415] Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415) for more details.
+
+    Example:
+        ```py
+        gelu = fl.GeLU()
+
+        tensor = torch.tensor([[-1.0, 0.0, 1.0]])
+        output = gelu(tensor)
+        ```
+    """
+
     def __init__(self) -> None:
         super().__init__()
 
@@ -50,18 +118,36 @@ class ApproximateGeLU(Activation):
 
 
 class Sigmoid(Activation):
+    """Sigmoid activation function.
+
+    Example:
+        ```py
+        sigmoid = fl.Sigmoid()
+
+        tensor = torch.tensor([[-1.0, 0.0, 1.0]])
+        output = sigmoid(tensor)
+        ```
+    """
+
     def __init__(self) -> None:
         super().__init__()
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.sigmoid()
+        return sigmoid(x)
 
 
 class GLU(Activation):
-    """
-    Gated Linear Unit activation layer.
-
-    See https://arxiv.org/abs/2002.05202v1 for details.
-    """
+    """Gated Linear Unit activation function.
+
+    See [[arXiv:2002.05202] GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202) for more details.
+
+    Example:
+        ```py
+        glu = fl.GLU(fl.ReLU())
+
+        tensor = torch.tensor([[-1.0, 0.0, 1.0, 1.0]])
+        output = glu(tensor)
+        ```
+    """
 
     def __init__(self, activation: Activation) -> None:
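Note on the `Example:` blocks added above: they reference two names the docstrings do not import themselves, `torch` and `fl`. The sketch below is a minimal way to run one of them outside the docs, assuming the refiners package is installed and that `fl` is the usual alias for `refiners.fluxion.layers`; it is an illustration, not part of this commit.

```py
# Minimal sketch for running the ReLU docstring example, assuming
# `fl` aliases refiners.fluxion.layers (the convention used in the docs).
import torch

import refiners.fluxion.layers as fl

relu = fl.ReLU()

tensor = torch.tensor([[-1.0, 0.0, 1.0]])
output = relu(tensor)  # calls ReLU.forward, i.e. torch.nn.functional.relu

expected_output = torch.tensor([[0.0, 0.0, 1.0]])
assert torch.allclose(output, expected_output)
```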
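For context on the `GeLUApproximation` values documented in the diff: the enum only names the variants, and how `GeLU` selects among them is not shown in this hunk. The sketch below is a hypothetical illustration, in plain PyTorch, of what the three variants compute; the sigmoid form uses the `x * sigmoid(1.702 * x)` approximation from the GELU paper.

```py
# Hypothetical illustration of the three GeLUApproximation variants
# using plain PyTorch; not the refiners wiring, which this hunk omits.
import torch
from torch.nn.functional import gelu

x = torch.linspace(-3.0, 3.0, steps=7)

exact = gelu(x)                                # NONE: exact erf-based formula
tanh_approx = gelu(x, approximate="tanh")      # TANH: tanh-based approximation
sigmoid_approx = x * torch.sigmoid(1.702 * x)  # SIGMOID: sigmoid-based approximation

# Both approximations stay close to the exact value.
assert torch.allclose(exact, tanh_approx, atol=1e-3)
assert torch.allclose(exact, sigmoid_approx, atol=5e-2)
```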