mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
(doc/fluxion/norm) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
3282782b56
commit
49847658e9
|
@ -1,10 +1,37 @@
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import Tensor, device as Device, dtype as DType, nn, ones, sqrt, zeros
|
from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
|
||||||
|
from torch.nn import (
|
||||||
|
GroupNorm as _GroupNorm,
|
||||||
|
InstanceNorm2d as _InstanceNorm2d,
|
||||||
|
LayerNorm as _LayerNorm,
|
||||||
|
Parameter as TorchParameter,
|
||||||
|
)
|
||||||
|
|
||||||
from refiners.fluxion.layers.module import Module, WeightedModule
|
from refiners.fluxion.layers.module import Module, WeightedModule
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm, WeightedModule):
|
class LayerNorm(_LayerNorm, WeightedModule):
|
||||||
|
"""Layer Normalization layer.
|
||||||
|
|
||||||
|
This layer wraps [`torch.nn.LayerNorm`][torch.nn.LayerNorm].
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Float[Tensor, "batch *normalized_shape"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch *normalized_shape"]):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
layernorm = fl.LayerNorm(normalized_shape=128)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 128)
|
||||||
|
output = layernorm(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 128)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
normalized_shape: int | list[int],
|
normalized_shape: int | list[int],
|
||||||
|
@ -21,7 +48,28 @@ class LayerNorm(nn.LayerNorm, WeightedModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm(nn.GroupNorm, WeightedModule):
|
class GroupNorm(_GroupNorm, WeightedModule):
|
||||||
|
"""Group Normalization layer.
|
||||||
|
|
||||||
|
This layer wraps [`torch.nn.GroupNorm`][torch.nn.GroupNorm].
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Float[Tensor, "batch channels *normalized_shape"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch channels *normalized_shape"]):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
groupnorm = fl.GroupNorm(channels=128, num_groups=8)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 128, 8)
|
||||||
|
output = groupnorm(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 128, 8)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
|
@ -44,12 +92,15 @@ class GroupNorm(nn.GroupNorm, WeightedModule):
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm2d(WeightedModule):
|
class LayerNorm2d(WeightedModule):
|
||||||
"""
|
"""2D Layer Normalization layer.
|
||||||
2D Layer Normalization module.
|
|
||||||
|
|
||||||
Parameters:
|
This layer applies Layer Normalization along the 2nd dimension of a 4D tensor.
|
||||||
channels (int): Number of channels in the input tensor.
|
|
||||||
eps (float, optional): A small constant for numerical stability. Default: 1e-6.
|
Receives:
|
||||||
|
(Float[Tensor, "batch channels height width"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch channels height width"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -60,11 +111,14 @@ class LayerNorm2d(WeightedModule):
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
|
self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
|
||||||
self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
|
self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Float[Tensor, "batch channels height width"],
|
||||||
|
) -> Float[Tensor, "batch channels height width"]:
|
||||||
x_mean = x.mean(1, keepdim=True)
|
x_mean = x.mean(1, keepdim=True)
|
||||||
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
|
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
|
||||||
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
||||||
|
@ -72,7 +126,18 @@ class LayerNorm2d(WeightedModule):
|
||||||
return x_out
|
return x_out
|
||||||
|
|
||||||
|
|
||||||
class InstanceNorm2d(nn.InstanceNorm2d, Module):
|
class InstanceNorm2d(_InstanceNorm2d, Module):
|
||||||
|
"""Instance Normalization layer.
|
||||||
|
|
||||||
|
This layer wraps [`torch.nn.InstanceNorm2d`][torch.nn.InstanceNorm2d].
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Float[Tensor, "batch channels height width"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch channels height width"]):
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
|
|
Loading…
Reference in a new issue