Add inner_dim Parameter to Attention Layer in Fluxion

Author: limiteinductive, 2023-08-28 15:47:29 +02:00 (committed by Benjamin Trom)
Commit: 8615dbdbde
Parent: 7ca6bd0ccd

@@ -66,6 +66,7 @@ class Attention(Chain):
         "heads_dim",
         "key_embedding_dim",
         "value_embedding_dim",
+        "inner_dim",
         "use_bias",
         "is_causal",
     ]
@@ -76,6 +77,7 @@ class Attention(Chain):
         num_heads: int = 1,
         key_embedding_dim: int | None = None,
         value_embedding_dim: int | None = None,
+        inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
         device: Device | str | None = None,
@@ -89,27 +91,28 @@ class Attention(Chain):
         self.heads_dim = embedding_dim // num_heads
         self.key_embedding_dim = key_embedding_dim or embedding_dim
         self.value_embedding_dim = value_embedding_dim or embedding_dim
+        self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
         super().__init__(
             Distribute(
                 Linear(
                     in_features=self.embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.key_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.value_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
@@ -117,7 +120,7 @@ class Attention(Chain):
             ),
             ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
             Linear(
-                in_features=self.embedding_dim,
+                in_features=self.inner_dim,
                 out_features=self.embedding_dim,
                 bias=True,
                 device=device,
@@ -130,6 +133,7 @@ class SelfAttention(Attention):
     def __init__(
         self,
         embedding_dim: int,
+        inner_dim: int | None = None,
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
@@ -138,6 +142,7 @@ class SelfAttention(Attention):
     ) -> None:
         super().__init__(
             embedding_dim=embedding_dim,
+            inner_dim=inner_dim,
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
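
With this change, the query, key, and value projections map their respective embedding dims to inner_dim, while the output projection maps inner_dim back to embedding_dim; leaving inner_dim unset keeps the previous behaviour, since it falls back to embedding_dim. The sketch below illustrates the resulting shapes in plain PyTorch rather than through the Fluxion Attention chain itself; the sizes (320/640, 8 heads, sequence length 77) are arbitrary example values, not taken from the commit.

import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention

# Arbitrary example sizes: inner_dim differs from embedding_dim.
embedding_dim, inner_dim, num_heads, seq_len = 320, 640, 8, 77

# Q/K/V projections now target inner_dim instead of embedding_dim...
to_q = nn.Linear(embedding_dim, inner_dim)
to_k = nn.Linear(embedding_dim, inner_dim)
to_v = nn.Linear(embedding_dim, inner_dim)
# ...and the output projection maps inner_dim back to embedding_dim.
to_out = nn.Linear(inner_dim, embedding_dim)

def split_heads(t: torch.Tensor) -> torch.Tensor:
    # (batch, seq, inner_dim) -> (batch, heads, seq, inner_dim // heads)
    b, n, _ = t.shape
    return t.reshape(b, n, num_heads, inner_dim // num_heads).transpose(1, 2)

x = torch.randn(1, seq_len, embedding_dim)
q, k, v = to_q(x), to_k(x), to_v(x)  # each (1, 77, 640)
attended = scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
merged = attended.transpose(1, 2).reshape(1, seq_len, inner_dim)
out = to_out(merged)
print(out.shape)  # torch.Size([1, 77, 320]): back to embedding_dim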