GroupNorm and LayerNorm must be affine to be WeightedModules

This commit is contained in:
Pierre Chapuis 2023-08-08 12:13:01 +02:00
parent 32425016c8
commit e10f761a84
2 changed files with 4 additions and 7 deletions

View file

@ -9,14 +9,13 @@ class LayerNorm(_LayerNorm, WeightedModule):
self,
normalized_shape: int | list[int],
eps: float = 0.00001,
elementwise_affine: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
elementwise_affine=True, # otherwise not a WeightedModule
device=device,
dtype=dtype,
)
@ -28,7 +27,6 @@ class GroupNorm(_GroupNorm, WeightedModule):
channels: int,
num_groups: int,
eps: float = 1e-5,
affine: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
@ -36,14 +34,13 @@ class GroupNorm(_GroupNorm, WeightedModule):
num_groups=num_groups,
num_channels=channels,
eps=eps,
affine=affine,
affine=True, # otherwise not a WeightedModule
device=device,
dtype=dtype,
)
self.channels = channels
self.num_groups = num_groups
self.eps = eps
self.affine = affine
class LayerNorm2d(WeightedModule):

View file

@ -143,14 +143,14 @@ class CrossAttentionBlock2d(Sum):
in_block = (
Chain(
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
Transpose(1, 2),
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
)
if use_linear_projection
else Chain(
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
Transpose(1, 2),