mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/norm) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
3282782b56
commit
49847658e9
|
@ -1,10 +1,37 @@
|
|||
from jaxtyping import Float
|
||||
from torch import Tensor, device as Device, dtype as DType, nn, ones, sqrt, zeros
|
||||
from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
|
||||
from torch.nn import (
|
||||
GroupNorm as _GroupNorm,
|
||||
InstanceNorm2d as _InstanceNorm2d,
|
||||
LayerNorm as _LayerNorm,
|
||||
Parameter as TorchParameter,
|
||||
)
|
||||
|
||||
from refiners.fluxion.layers.module import Module, WeightedModule
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm, WeightedModule):
|
||||
class LayerNorm(_LayerNorm, WeightedModule):
|
||||
"""Layer Normalization layer.
|
||||
|
||||
This layer wraps [`torch.nn.LayerNorm`][torch.nn.LayerNorm].
|
||||
|
||||
Receives:
|
||||
(Float[Tensor, "batch *normalized_shape"]):
|
||||
|
||||
Returns:
|
||||
(Float[Tensor, "batch *normalized_shape"]):
|
||||
|
||||
Example:
|
||||
```py
|
||||
layernorm = fl.LayerNorm(normalized_shape=128)
|
||||
|
||||
tensor = torch.randn(2, 128)
|
||||
output = layernorm(tensor)
|
||||
|
||||
assert output.shape == (2, 128)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape: int | list[int],
|
||||
|
@ -21,7 +48,28 @@ class LayerNorm(nn.LayerNorm, WeightedModule):
|
|||
)
|
||||
|
||||
|
||||
class GroupNorm(nn.GroupNorm, WeightedModule):
|
||||
class GroupNorm(_GroupNorm, WeightedModule):
|
||||
"""Group Normalization layer.
|
||||
|
||||
This layer wraps [`torch.nn.GroupNorm`][torch.nn.GroupNorm].
|
||||
|
||||
Receives:
|
||||
(Float[Tensor, "batch channels *normalized_shape"]):
|
||||
|
||||
Returns:
|
||||
(Float[Tensor, "batch channels *normalized_shape"]):
|
||||
|
||||
Example:
|
||||
```py
|
||||
groupnorm = fl.GroupNorm(channels=128, num_groups=8)
|
||||
|
||||
tensor = torch.randn(2, 128, 8)
|
||||
output = groupnorm(tensor)
|
||||
|
||||
assert output.shape == (2, 128, 8)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
|
@ -44,12 +92,15 @@ class GroupNorm(nn.GroupNorm, WeightedModule):
|
|||
|
||||
|
||||
class LayerNorm2d(WeightedModule):
|
||||
"""
|
||||
2D Layer Normalization module.
|
||||
"""2D Layer Normalization layer.
|
||||
|
||||
Parameters:
|
||||
channels (int): Number of channels in the input tensor.
|
||||
eps (float, optional): A small constant for numerical stability. Default: 1e-6.
|
||||
This layer applies Layer Normalization along the 2nd dimension of a 4D tensor.
|
||||
|
||||
Receives:
|
||||
(Float[Tensor, "batch channels height width"]):
|
||||
|
||||
Returns:
|
||||
(Float[Tensor, "batch channels height width"]):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -60,11 +111,14 @@ class LayerNorm2d(WeightedModule):
|
|||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
|
||||
self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
|
||||
self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
|
||||
self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
|
||||
def forward(
|
||||
self,
|
||||
x: Float[Tensor, "batch channels height width"],
|
||||
) -> Float[Tensor, "batch channels height width"]:
|
||||
x_mean = x.mean(1, keepdim=True)
|
||||
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
|
||||
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
||||
|
@ -72,7 +126,18 @@ class LayerNorm2d(WeightedModule):
|
|||
return x_out
|
||||
|
||||
|
||||
class InstanceNorm2d(nn.InstanceNorm2d, Module):
|
||||
class InstanceNorm2d(_InstanceNorm2d, Module):
|
||||
"""Instance Normalization layer.
|
||||
|
||||
This layer wraps [`torch.nn.InstanceNorm2d`][torch.nn.InstanceNorm2d].
|
||||
|
||||
Receives:
|
||||
(Float[Tensor, "batch channels height width"]):
|
||||
|
||||
Returns:
|
||||
(Float[Tensor, "batch channels height width"]):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
|
|
Loading…
Reference in a new issue