Add inner_dim Parameter to Attention Layer in Fluxion

Authored by limiteinductive on 2023-08-28 15:47:29 +02:00, committed by Benjamin Trom
parent 7ca6bd0ccd
commit 8615dbdbde


@@ -66,6 +66,7 @@ class Attention(Chain):
         "heads_dim",
         "key_embedding_dim",
         "value_embedding_dim",
+        "inner_dim",
         "use_bias",
         "is_causal",
     ]
@@ -76,6 +77,7 @@ class Attention(Chain):
         num_heads: int = 1,
         key_embedding_dim: int | None = None,
         value_embedding_dim: int | None = None,
+        inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
         device: Device | str | None = None,
@@ -89,27 +91,28 @@ class Attention(Chain):
         self.heads_dim = embedding_dim // num_heads
         self.key_embedding_dim = key_embedding_dim or embedding_dim
         self.value_embedding_dim = value_embedding_dim or embedding_dim
+        self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
         super().__init__(
             Distribute(
                 Linear(
                     in_features=self.embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.key_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.value_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
@@ -117,7 +120,7 @@ class Attention(Chain):
             ),
             ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
             Linear(
-                in_features=self.embedding_dim,
+                in_features=self.inner_dim,
                 out_features=self.embedding_dim,
                 bias=True,
                 device=device,
@@ -130,6 +133,7 @@ class SelfAttention(Attention):
     def __init__(
         self,
         embedding_dim: int,
+        inner_dim: int | None = None,
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
@@ -138,6 +142,7 @@ class SelfAttention(Attention):
     ) -> None:
         super().__init__(
             embedding_dim=embedding_dim,
+            inner_dim=inner_dim,
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
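
For reference, a minimal usage sketch of the updated constructors. The sizes below are illustrative and not taken from the commit, and the import path is assumed from the refiners package layout; the point is that inner_dim now sets the width of the query/key/value projections, while the final Linear still projects back to embedding_dim. Since inner_dim defaults to embedding_dim, existing callers are unaffected.

import torch
from refiners.fluxion.layers import Attention, SelfAttention  # assumed import path

# Cross-attention with a wider inner projection (all sizes illustrative).
attn = Attention(
    embedding_dim=320,       # width of the query input and of the output
    num_heads=8,
    key_embedding_dim=768,   # e.g. text-encoder features used for key/value
    value_embedding_dim=768,
    inner_dim=640,           # new parameter: Q/K/V are projected to 640; should be divisible by num_heads
)

query = torch.randn(1, 16, 320)
key = torch.randn(1, 77, 768)
value = torch.randn(1, 77, 768)
output = attn(query, key, value)  # shape (1, 16, 320): the output Linear maps inner_dim back to embedding_dim

# SelfAttention forwards inner_dim to Attention and uses a single tensor as Q, K and V.
self_attn = SelfAttention(embedding_dim=320, inner_dim=640, num_heads=8)
output = self_attn(torch.randn(1, 16, 320))  # shape (1, 16, 320)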