mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
GroupNorm and LayerNorm must be affine to be WeightedModules
This commit is contained in:
parent
32425016c8
commit
e10f761a84
|
@ -9,14 +9,13 @@ class LayerNorm(_LayerNorm, WeightedModule):
|
|||
self,
|
||||
normalized_shape: int | list[int],
|
||||
eps: float = 0.00001,
|
||||
elementwise_affine: bool = True,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__( # type: ignore
|
||||
normalized_shape=normalized_shape,
|
||||
eps=eps,
|
||||
elementwise_affine=elementwise_affine,
|
||||
elementwise_affine=True, # otherwise not a WeightedModule
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -28,7 +27,6 @@ class GroupNorm(_GroupNorm, WeightedModule):
|
|||
channels: int,
|
||||
num_groups: int,
|
||||
eps: float = 1e-5,
|
||||
affine: bool = True,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
|
@ -36,14 +34,13 @@ class GroupNorm(_GroupNorm, WeightedModule):
|
|||
num_groups=num_groups,
|
||||
num_channels=channels,
|
||||
eps=eps,
|
||||
affine=affine,
|
||||
affine=True, # otherwise not a WeightedModule
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.channels = channels
|
||||
self.num_groups = num_groups
|
||||
self.eps = eps
|
||||
self.affine = affine
|
||||
|
||||
|
||||
class LayerNorm2d(WeightedModule):
|
||||
|
|
|
@ -143,14 +143,14 @@ class CrossAttentionBlock2d(Sum):
|
|||
|
||||
in_block = (
|
||||
Chain(
|
||||
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
|
||||
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
|
||||
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
||||
Transpose(1, 2),
|
||||
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
|
||||
)
|
||||
if use_linear_projection
|
||||
else Chain(
|
||||
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
|
||||
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
|
||||
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
|
||||
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
||||
Transpose(1, 2),
|
||||
|
|
Loading…
Reference in a new issue