diff --git a/src/refiners/fluxion/layers/norm.py b/src/refiners/fluxion/layers/norm.py
index 5ac6784..bc7f0dd 100644
--- a/src/refiners/fluxion/layers/norm.py
+++ b/src/refiners/fluxion/layers/norm.py
@@ -1,10 +1,37 @@
 from jaxtyping import Float
-from torch import Tensor, device as Device, dtype as DType, nn, ones, sqrt, zeros
+from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
+from torch.nn import (
+    GroupNorm as _GroupNorm,
+    InstanceNorm2d as _InstanceNorm2d,
+    LayerNorm as _LayerNorm,
+    Parameter as TorchParameter,
+)
 
 from refiners.fluxion.layers.module import Module, WeightedModule
 
 
-class LayerNorm(nn.LayerNorm, WeightedModule):
+class LayerNorm(_LayerNorm, WeightedModule):
+    """Layer Normalization layer.
+
+    This layer wraps [`torch.nn.LayerNorm`][torch.nn.LayerNorm].
+
+    Receives:
+        (Float[Tensor, "batch *normalized_shape"]):
+
+    Returns:
+        (Float[Tensor, "batch *normalized_shape"]):
+
+    Example:
+        ```py
+        layernorm = fl.LayerNorm(normalized_shape=128)
+
+        tensor = torch.randn(2, 128)
+        output = layernorm(tensor)
+
+        assert output.shape == (2, 128)
+        ```
+    """
+
     def __init__(
         self,
         normalized_shape: int | list[int],
@@ -21,7 +48,28 @@ class LayerNorm(nn.LayerNorm, WeightedModule):
         )
 
 
-class GroupNorm(nn.GroupNorm, WeightedModule):
+class GroupNorm(_GroupNorm, WeightedModule):
+    """Group Normalization layer.
+
+    This layer wraps [`torch.nn.GroupNorm`][torch.nn.GroupNorm].
+
+    Receives:
+        (Float[Tensor, "batch channels *normalized_shape"]):
+
+    Returns:
+        (Float[Tensor, "batch channels *normalized_shape"]):
+
+    Example:
+        ```py
+        groupnorm = fl.GroupNorm(channels=128, num_groups=8)
+
+        tensor = torch.randn(2, 128, 8)
+        output = groupnorm(tensor)
+
+        assert output.shape == (2, 128, 8)
+        ```
+    """
+
     def __init__(
         self,
         channels: int,
@@ -44,12 +92,15 @@ class GroupNorm(nn.GroupNorm, WeightedModule):
 
 
 class LayerNorm2d(WeightedModule):
-    """
-    2D Layer Normalization module.
+    """2D Layer Normalization layer.
 
-    Parameters:
-        channels (int): Number of channels in the input tensor.
-        eps (float, optional): A small constant for numerical stability. Default: 1e-6.
+    This layer applies Layer Normalization along the 2nd dimension of a 4D tensor.
+
+    Receives:
+        (Float[Tensor, "batch channels height width"]):
+
+    Returns:
+        (Float[Tensor, "batch channels height width"]):
     """
 
     def __init__(
@@ -60,11 +111,14 @@
         dtype: DType | None = None,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
-        self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
+        self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
+        self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
         self.eps = eps
 
-    def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
+    def forward(
+        self,
+        x: Float[Tensor, "batch channels height width"],
+    ) -> Float[Tensor, "batch channels height width"]:
         x_mean = x.mean(1, keepdim=True)
         x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
         x_norm = (x - x_mean) / sqrt(x_var + self.eps)
@@ -72,7 +126,18 @@
         return x_out
 
 
-class InstanceNorm2d(nn.InstanceNorm2d, Module):
+class InstanceNorm2d(_InstanceNorm2d, Module):
+    """Instance Normalization layer.
+
+    This layer wraps [`torch.nn.InstanceNorm2d`][torch.nn.InstanceNorm2d].
+
+    Receives:
+        (Float[Tensor, "batch channels height width"]):
+
+    Returns:
+        (Float[Tensor, "batch channels height width"]):
+    """
+
     def __init__(
         self,
         num_features: int,
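
Note: unlike `LayerNorm` and `GroupNorm`, the new `LayerNorm2d` docstring ships no `Example` block. Below is a minimal usage sketch in the same style, assuming `LayerNorm2d` is re-exported as `fl.LayerNorm2d` alongside the other layers (that re-export is an assumption, not part of this patch):

```py
import torch
import refiners.fluxion.layers as fl  # assumes fl re-exports LayerNorm2d

# LayerNorm2d normalizes each spatial position across the channel dimension (dim 1),
# then applies a learned per-channel scale (weight) and shift (bias).
layernorm2d = fl.LayerNorm2d(channels=128)

tensor = torch.randn(2, 128, 32, 32)
output = layernorm2d(tensor)

assert output.shape == (2, 128, 32, 32)

# The statistics can be reproduced by hand, mirroring the forward pass in the diff.
# At default init (weight = ones, bias = zeros) the per-channel affine is the identity,
# so the normalized tensor matches the layer output.
mean = tensor.mean(1, keepdim=True)
var = (tensor - mean).pow(2).mean(1, keepdim=True)
manual = (tensor - mean) / torch.sqrt(var + layernorm2d.eps)

assert torch.allclose(output, manual, atol=1e-5)
```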