(doc/fluxion/norm) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 21:44:23 +00:00 committed by Laureηt
parent 3282782b56
commit 49847658e9

View file

@ -1,10 +1,37 @@
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType, nn, ones, sqrt, zeros
from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
from torch.nn import (
GroupNorm as _GroupNorm,
InstanceNorm2d as _InstanceNorm2d,
LayerNorm as _LayerNorm,
Parameter as TorchParameter,
)
from refiners.fluxion.layers.module import Module, WeightedModule
class LayerNorm(nn.LayerNorm, WeightedModule):
class LayerNorm(_LayerNorm, WeightedModule):
"""Layer Normalization layer.
This layer wraps [`torch.nn.LayerNorm`][torch.nn.LayerNorm].
Receives:
(Float[Tensor, "batch *normalized_shape"]):
Returns:
(Float[Tensor, "batch *normalized_shape"]):
Example:
```py
layernorm = fl.LayerNorm(normalized_shape=128)
tensor = torch.randn(2, 128)
output = layernorm(tensor)
assert output.shape == (2, 128)
```
"""
def __init__(
self,
normalized_shape: int | list[int],
@ -21,7 +48,28 @@ class LayerNorm(nn.LayerNorm, WeightedModule):
)
class GroupNorm(nn.GroupNorm, WeightedModule):
class GroupNorm(_GroupNorm, WeightedModule):
"""Group Normalization layer.
This layer wraps [`torch.nn.GroupNorm`][torch.nn.GroupNorm].
Receives:
(Float[Tensor, "batch channels *normalized_shape"]):
Returns:
(Float[Tensor, "batch channels *normalized_shape"]):
Example:
```py
groupnorm = fl.GroupNorm(channels=128, num_groups=8)
tensor = torch.randn(2, 128, 8)
output = groupnorm(tensor)
assert output.shape == (2, 128, 8)
```
"""
def __init__(
self,
channels: int,
@ -44,12 +92,15 @@ class GroupNorm(nn.GroupNorm, WeightedModule):
class LayerNorm2d(WeightedModule):
"""
2D Layer Normalization module.
"""2D Layer Normalization layer.
Parameters:
channels (int): Number of channels in the input tensor.
eps (float, optional): A small constant for numerical stability. Default: 1e-6.
This layer applies Layer Normalization along the 2nd dimension of a 4D tensor.
Receives:
(Float[Tensor, "batch channels height width"]):
Returns:
(Float[Tensor, "batch channels height width"]):
"""
def __init__(
@ -60,11 +111,14 @@ class LayerNorm2d(WeightedModule):
dtype: DType | None = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
self.eps = eps
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
def forward(
self,
x: Float[Tensor, "batch channels height width"],
) -> Float[Tensor, "batch channels height width"]:
x_mean = x.mean(1, keepdim=True)
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
@ -72,7 +126,18 @@ class LayerNorm2d(WeightedModule):
return x_out
class InstanceNorm2d(nn.InstanceNorm2d, Module):
class InstanceNorm2d(_InstanceNorm2d, Module):
"""Instance Normalization layer.
This layer wraps [`torch.nn.InstanceNorm2d`][torch.nn.InstanceNorm2d].
Receives:
(Float[Tensor, "batch channels height width"]):
Returns:
(Float[Tensor, "batch channels height width"]):
"""
def __init__(
self,
num_features: int,