GroupNorm and LayerNorm must be affine to be WeightedModules

Pierre Chapuis 2023-08-08 12:13:01 +02:00
parent 32425016c8
commit e10f761a84
2 changed files with 4 additions and 7 deletions
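Rationale: these classes wrap the corresponding torch norm layers, and a WeightedModule only makes sense when the layer actually carries learnable weights, which is only the case in affine mode. A minimal sketch of the distinction, assuming plain PyTorch (the WeightedModule name comes from this repo; everything else below is illustrative):

    import torch

    affine = torch.nn.GroupNorm(num_groups=2, num_channels=4, affine=True)
    plain = torch.nn.GroupNorm(num_groups=2, num_channels=4, affine=False)

    # Affine norm layers own learnable scale/shift parameters...
    assert isinstance(affine.weight, torch.nn.Parameter)
    assert isinstance(affine.bias, torch.nn.Parameter)

    # ...while non-affine ones register None instead, so a wrapper that
    # promises weights (a WeightedModule) would have nothing to expose.
    assert plain.weight is None and plain.bias is None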


@@ -9,14 +9,13 @@ class LayerNorm(_LayerNorm, WeightedModule):
         self,
         normalized_shape: int | list[int],
         eps: float = 0.00001,
-        elementwise_affine: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
         super().__init__(  # type: ignore
             normalized_shape=normalized_shape,
             eps=eps,
-            elementwise_affine=elementwise_affine,
+            elementwise_affine=True,  # otherwise not a WeightedModule
             device=device,
             dtype=dtype,
         )
@@ -28,7 +27,6 @@ class GroupNorm(_GroupNorm, WeightedModule):
         channels: int,
         num_groups: int,
         eps: float = 1e-5,
-        affine: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -36,14 +34,13 @@ class GroupNorm(_GroupNorm, WeightedModule):
             num_groups=num_groups,
             num_channels=channels,
             eps=eps,
-            affine=affine,
+            affine=True,  # otherwise not a WeightedModule
             device=device,
             dtype=dtype,
         )
         self.channels = channels
         self.num_groups = num_groups
         self.eps = eps
-        self.affine = affine
 
 
 class LayerNorm2d(WeightedModule):


@@ -143,14 +143,14 @@ class CrossAttentionBlock2d(Sum):
         in_block = (
             Chain(
-                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
+                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
                 StatefulFlatten(context="flatten", key="sizes", start_dim=2),
                 Transpose(1, 2),
                 Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
             )
             if use_linear_projection
             else Chain(
-                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
+                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
                 Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
                 StatefulFlatten(context="flatten", key="sizes", start_dim=2),
                 Transpose(1, 2),