Mirror of https://github.com/finegrain-ai/refiners.git
Add inner_dim Parameter to Attention Layer in Fluxion
commit 8615dbdbde
parent 7ca6bd0ccd
@@ -66,6 +66,7 @@ class Attention(Chain):
         "heads_dim",
         "key_embedding_dim",
         "value_embedding_dim",
+        "inner_dim",
         "use_bias",
         "is_causal",
     ]
@@ -76,6 +77,7 @@ class Attention(Chain):
         num_heads: int = 1,
         key_embedding_dim: int | None = None,
         value_embedding_dim: int | None = None,
+        inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
         device: Device | str | None = None,
@@ -89,27 +91,28 @@ class Attention(Chain):
         self.heads_dim = embedding_dim // num_heads
         self.key_embedding_dim = key_embedding_dim or embedding_dim
         self.value_embedding_dim = value_embedding_dim or embedding_dim
+        self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
         super().__init__(
             Distribute(
                 Linear(
                     in_features=self.embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.key_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.value_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
@@ -117,7 +120,7 @@ class Attention(Chain):
             ),
             ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
             Linear(
-                in_features=self.embedding_dim,
+                in_features=self.inner_dim,
                 out_features=self.embedding_dim,
                 bias=True,
                 device=device,
@@ -130,6 +133,7 @@ class SelfAttention(Attention):
     def __init__(
         self,
         embedding_dim: int,
+        inner_dim: int | None = None,
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
@@ -138,6 +142,7 @@ class SelfAttention(Attention):
     ) -> None:
         super().__init__(
             embedding_dim=embedding_dim,
+            inner_dim=inner_dim,
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
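For reference, here is a minimal usage sketch of the new parameter. It is not part of the commit, and it assumes the layers are importable as refiners.fluxion.layers.Attention and refiners.fluxion.layers.SelfAttention and that an Attention chain is called with query, key and value tensors, as the diff above suggests: with inner_dim unset the behaviour is unchanged, while with inner_dim set the Q/K/V projections map to inner_dim and the output projection maps back to embedding_dim.

import torch
from refiners.fluxion.layers import Attention, SelfAttention

# Default behaviour: inner_dim falls back to embedding_dim, so nothing changes.
attn = Attention(embedding_dim=320, num_heads=8)
q = k = v = torch.randn(1, 77, 320)
assert attn(q, k, v).shape == (1, 77, 320)

# With inner_dim set, Q/K/V are projected to inner_dim (split across the heads),
# and the final Linear maps back to embedding_dim, so the interface is unchanged.
cross_attn = Attention(
    embedding_dim=320,
    key_embedding_dim=768,
    value_embedding_dim=768,
    inner_dim=640,
    num_heads=8,
)
q = torch.randn(1, 77, 320)
k = v = torch.randn(1, 77, 768)
assert cross_attn(q, k, v).shape == (1, 77, 320)

# SelfAttention simply forwards inner_dim to Attention.
self_attn = SelfAttention(embedding_dim=320, inner_dim=640, num_heads=8)
assert self_attn(torch.randn(1, 77, 320)).shape == (1, 77, 320)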