mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
GroupNorm and LayerNorm must be affine to be WeightedModules
This commit is contained in:
parent
32425016c8
commit
e10f761a84
|
@ -9,14 +9,13 @@ class LayerNorm(_LayerNorm, WeightedModule):
|
||||||
self,
|
self,
|
||||||
normalized_shape: int | list[int],
|
normalized_shape: int | list[int],
|
||||||
eps: float = 0.00001,
|
eps: float = 0.00001,
|
||||||
elementwise_affine: bool = True,
|
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__( # type: ignore
|
super().__init__( # type: ignore
|
||||||
normalized_shape=normalized_shape,
|
normalized_shape=normalized_shape,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
elementwise_affine=elementwise_affine,
|
elementwise_affine=True, # otherwise not a WeightedModule
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -28,7 +27,6 @@ class GroupNorm(_GroupNorm, WeightedModule):
|
||||||
channels: int,
|
channels: int,
|
||||||
num_groups: int,
|
num_groups: int,
|
||||||
eps: float = 1e-5,
|
eps: float = 1e-5,
|
||||||
affine: bool = True,
|
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -36,14 +34,13 @@ class GroupNorm(_GroupNorm, WeightedModule):
|
||||||
num_groups=num_groups,
|
num_groups=num_groups,
|
||||||
num_channels=channels,
|
num_channels=channels,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
affine=affine,
|
affine=True, # otherwise not a WeightedModule
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.num_groups = num_groups
|
self.num_groups = num_groups
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.affine = affine
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm2d(WeightedModule):
|
class LayerNorm2d(WeightedModule):
|
||||||
|
|
|
@ -143,14 +143,14 @@ class CrossAttentionBlock2d(Sum):
|
||||||
|
|
||||||
in_block = (
|
in_block = (
|
||||||
Chain(
|
Chain(
|
||||||
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
|
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
|
||||||
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
||||||
Transpose(1, 2),
|
Transpose(1, 2),
|
||||||
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
|
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
|
||||||
)
|
)
|
||||||
if use_linear_projection
|
if use_linear_projection
|
||||||
else Chain(
|
else Chain(
|
||||||
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
|
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
|
||||||
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
|
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
|
||||||
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
||||||
Transpose(1, 2),
|
Transpose(1, 2),
|
||||||
|
|
Loading…
Reference in a new issue