diff --git a/src/refiners/fluxion/layers/attentions.py b/src/refiners/fluxion/layers/attentions.py
index 13a16af..8676bff 100644
--- a/src/refiners/fluxion/layers/attentions.py
+++ b/src/refiners/fluxion/layers/attentions.py
@@ -66,6 +66,7 @@ class Attention(Chain):
         "heads_dim",
         "key_embedding_dim",
         "value_embedding_dim",
+        "inner_dim",
         "use_bias",
         "is_causal",
     ]
@@ -76,6 +77,7 @@ class Attention(Chain):
         num_heads: int = 1,
         key_embedding_dim: int | None = None,
         value_embedding_dim: int | None = None,
+        inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
         device: Device | str | None = None,
@@ -89,27 +91,28 @@ class Attention(Chain):
         self.heads_dim = embedding_dim // num_heads
         self.key_embedding_dim = key_embedding_dim or embedding_dim
         self.value_embedding_dim = value_embedding_dim or embedding_dim
+        self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
         super().__init__(
             Distribute(
                 Linear(
                     in_features=self.embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.key_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.value_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
@@ -117,7 +120,7 @@ class Attention(Chain):
             ),
             ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
             Linear(
-                in_features=self.embedding_dim,
+                in_features=self.inner_dim,
                 out_features=self.embedding_dim,
                 bias=True,
                 device=device,
@@ -130,6 +133,7 @@ class SelfAttention(Attention):
     def __init__(
         self,
         embedding_dim: int,
+        inner_dim: int | None = None,
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
@@ -138,6 +142,7 @@ class SelfAttention(Attention):
     ) -> None:
         super().__init__(
             embedding_dim=embedding_dim,
+            inner_dim=inner_dim,
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
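
Note (not part of the patch): a minimal usage sketch of the new inner_dim parameter, assuming the refiners package from this repo is importable via the path of the file touched above. The dimensions (320/640/768) are illustrative, not taken from the diff. The Q/K/V projections now map into inner_dim, and the final Linear maps inner_dim back to embedding_dim, so the output shape of the block is unchanged.

import torch
from refiners.fluxion.layers.attentions import Attention

# Cross-attention whose projections widen into a 640-dim inner space
# while the residual stream stays at 320 dims.
attn = Attention(
    embedding_dim=320,       # dim of the query sequence and of the output
    num_heads=8,             # heads split the inner dim: 640 / 8 = 80 each
    key_embedding_dim=768,   # e.g. text-encoder embeddings for K and V
    value_embedding_dim=768,
    inner_dim=640,           # new: falls back to embedding_dim when omitted
)

q = torch.randn(1, 32, 320)    # (batch, seq, embedding_dim)
kv = torch.randn(1, 77, 768)   # (batch, seq_kv, key/value_embedding_dim)
out = attn(q, kv, kv)          # Distribute routes each tensor to its Linear
print(out.shape)               # torch.Size([1, 32, 320]): out proj maps 640 -> 320

Because self.inner_dim defaults to embedding_dim (inner_dim or embedding_dim), existing callers and checkpoints that never pass inner_dim keep their current projection shapes.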