Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 14:18:46 +00:00)
Add inner_dim Parameter to Attention Layer in Fluxion
parent 7ca6bd0ccd
commit 8615dbdbde
@@ -66,6 +66,7 @@ class Attention(Chain):
         "heads_dim",
         "key_embedding_dim",
         "value_embedding_dim",
+        "inner_dim",
         "use_bias",
         "is_causal",
     ]
@@ -76,6 +77,7 @@ class Attention(Chain):
         num_heads: int = 1,
         key_embedding_dim: int | None = None,
         value_embedding_dim: int | None = None,
+        inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
         device: Device | str | None = None,
@@ -89,27 +91,28 @@ class Attention(Chain):
         self.heads_dim = embedding_dim // num_heads
         self.key_embedding_dim = key_embedding_dim or embedding_dim
         self.value_embedding_dim = value_embedding_dim or embedding_dim
+        self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
         super().__init__(
             Distribute(
                 Linear(
                     in_features=self.embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.key_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
                 ),
                 Linear(
                     in_features=self.value_embedding_dim,
-                    out_features=self.embedding_dim,
+                    out_features=self.inner_dim,
                     bias=self.use_bias,
                     device=device,
                     dtype=dtype,
@@ -117,7 +120,7 @@ class Attention(Chain):
             ),
             ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
             Linear(
-                in_features=self.embedding_dim,
+                in_features=self.inner_dim,
                 out_features=self.embedding_dim,
                 bias=True,
                 device=device,
@@ -130,6 +133,7 @@ class SelfAttention(Attention):
     def __init__(
         self,
         embedding_dim: int,
+        inner_dim: int | None = None,
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
@@ -138,6 +142,7 @@ class SelfAttention(Attention):
     ) -> None:
         super().__init__(
             embedding_dim=embedding_dim,
+            inner_dim=inner_dim,
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,