diff --git a/src/refiners/fluxion/layers/norm.py b/src/refiners/fluxion/layers/norm.py
index e088c95..11d7d3a 100644
--- a/src/refiners/fluxion/layers/norm.py
+++ b/src/refiners/fluxion/layers/norm.py
@@ -9,14 +9,13 @@ class LayerNorm(_LayerNorm, WeightedModule):
         self,
         normalized_shape: int | list[int],
         eps: float = 0.00001,
-        elementwise_affine: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
         super().__init__(  # type: ignore
             normalized_shape=normalized_shape,
             eps=eps,
-            elementwise_affine=elementwise_affine,
+            elementwise_affine=True,  # otherwise not a WeightedModule
             device=device,
             dtype=dtype,
         )
@@ -28,7 +27,6 @@ class GroupNorm(_GroupNorm, WeightedModule):
         channels: int,
         num_groups: int,
         eps: float = 1e-5,
-        affine: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -36,14 +34,13 @@ class GroupNorm(_GroupNorm, WeightedModule):
             num_groups=num_groups,
             num_channels=channels,
             eps=eps,
-            affine=affine,
+            affine=True,  # otherwise not a WeightedModule
             device=device,
             dtype=dtype,
         )
         self.channels = channels
         self.num_groups = num_groups
         self.eps = eps
-        self.affine = affine


 class LayerNorm2d(WeightedModule):
diff --git a/src/refiners/foundationals/latent_diffusion/cross_attention.py b/src/refiners/foundationals/latent_diffusion/cross_attention.py
index 9641596..592b165 100644
--- a/src/refiners/foundationals/latent_diffusion/cross_attention.py
+++ b/src/refiners/foundationals/latent_diffusion/cross_attention.py
@@ -143,14 +143,14 @@ class CrossAttentionBlock2d(Sum):

         in_block = (
             Chain(
-                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
+                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
                 StatefulFlatten(context="flatten", key="sizes", start_dim=2),
                 Transpose(1, 2),
                 Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
             )
             if use_linear_projection
             else Chain(
-                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
+                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
                 Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
                 StatefulFlatten(context="flatten", key="sizes", start_dim=2),
                 Transpose(1, 2),
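
For context, the rationale behind hard-coding the affine flags can be illustrated with plain PyTorch (this snippet is a sketch, not part of the patch): when affine / elementwise_affine is False, the underlying norm layer registers no weight or bias parameters, so a WeightedModule wrapper would have nothing to expose or load. Forcing the flag to True guarantees the parameters exist.

    import torch.nn as nn

    # Sketch only: shows why affine=True is required for a weight-carrying norm layer.
    with_affine = nn.GroupNorm(num_groups=32, num_channels=320, affine=True)
    without_affine = nn.GroupNorm(num_groups=32, num_channels=320, affine=False)

    assert with_affine.weight is not None and with_affine.bias is not None
    assert without_affine.weight is None and without_affine.bias is None  # nothing to load

    print(sum(p.numel() for p in with_affine.parameters()))     # 640 (320 weight + 320 bias)
    print(sum(p.numel() for p in without_affine.parameters()))  # 0

Since callers could no longer opt out of the affine transform, the explicit affine=True arguments at the GroupNorm call sites in cross_attention.py are dropped as redundant.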