Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)
add is_optimized option for attention
This commit is contained in:
parent fc2390ad1c
commit 121ef4df39

@@ -1,3 +1,5 @@
+import math
+import torch
 from jaxtyping import Float
 from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention  # type: ignore
 from torch import Tensor, device as Device, dtype as DType
@@ -18,11 +20,31 @@ def scaled_dot_product_attention(
     return _scaled_dot_product_attention(query, key, value, is_causal=is_causal)  # type: ignore


+def sparse_dot_product_attention_non_optimized(
+    query: Float[Tensor, "batch source_sequence_length dim"],
+    key: Float[Tensor, "batch target_sequence_length dim"],
+    value: Float[Tensor, "batch target_sequence_length dim"],
+    is_causal: bool = False,
+) -> Float[Tensor, "batch source_sequence_length dim"]:
+    if is_causal:
+        # TODO: implement causal attention
+        raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented")
+    _, _, _, dim = query.shape
+    attention = query @ key.permute(0, 1, 3, 2)
+    attention = attention / math.sqrt(dim)
+    attention = torch.softmax(input=attention, dim=-1)
+    return attention @ value
+
+
 class ScaledDotProductAttention(Module):
-    def __init__(self, num_heads: int = 1, is_causal: bool | None = None) -> None:
+    def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
+        self.dot_product = (
+            scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
+        )

     def forward(
         self,
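The new fallback is the textbook softmax(QK^T / sqrt(dim)) V computation, so outside the causal case it should match torch's fused kernel up to floating-point tolerance. A minimal self-contained sanity check, restating the same arithmetic rather than importing anything from refiners (the tensor sizes and tolerance below are arbitrary choices):

    import math

    import torch


    def naive_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        # Same arithmetic as sparse_dot_product_attention_non_optimized above:
        # softmax(Q @ K^T / sqrt(dim)) @ V over the last two dimensions.
        _, _, _, dim = query.shape
        scores = query @ key.permute(0, 1, 3, 2) / math.sqrt(dim)
        return torch.softmax(scores, dim=-1) @ value


    # (batch, heads, sequence_length, head_dim), arbitrary sizes for the check
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

    fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(fused, naive_attention(q, k, v), atol=1e-5)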
@@ -32,7 +54,7 @@ class ScaledDotProductAttention(Module):
         is_causal: bool | None = None,
     ) -> Float[Tensor, "batch num_queries dim"]:
         return self.merge_multi_head(
-            scaled_dot_product_attention(
+            x=self.dot_product(
                 query=self.split_to_multi_head(query),
                 key=self.split_to_multi_head(key),
                 value=self.split_to_multi_head(value),
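With forward now routed through self.dot_product, the two code paths are selected solely by the constructor flag. A usage sketch, assuming the class is re-exported from refiners.fluxion.layers (adjust the import to your checkout) and that inputs are (batch, sequence_length, embedding_dim) tensors as in the annotations above:

    import torch

    # Import path is an assumption; adjust to wherever ScaledDotProductAttention
    # is exported in your copy of refiners.
    from refiners.fluxion.layers import ScaledDotProductAttention

    q, k, v = (torch.randn(2, 10, 64) for _ in range(3))  # (batch, sequence_length, embedding_dim)

    fused = ScaledDotProductAttention(num_heads=8)                         # default: torch fused kernel
    fallback = ScaledDotProductAttention(num_heads=8, is_optimized=False)  # plain-PyTorch path

    # Both paths should agree up to floating-point tolerance in the non-causal case.
    assert torch.allclose(fused(q, k, v), fallback(q, k, v), atol=1e-5)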
@@ -69,6 +91,7 @@ class Attention(Chain):
         "inner_dim",
         "use_bias",
         "is_causal",
+        "is_optimized",
     ]

     def __init__(
@@ -80,6 +103,7 @@ class Attention(Chain):
         inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -94,6 +118,7 @@ class Attention(Chain):
         self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
         super().__init__(
             Distribute(
                 Linear(
@@ -118,7 +143,7 @@ class Attention(Chain):
                     dtype=dtype,
                 ),
             ),
-            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
+            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized),
             Linear(
                 in_features=self.inner_dim,
                 out_features=self.embedding_dim,
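Since Attention forwards the flag to its inner ScaledDotProductAttention, callers can switch the whole block to the fallback path at construction time. A sketch of the intended usage, assuming the class is exported from refiners.fluxion.layers and that the chain is called with three (batch, sequence_length, embedding_dim) tensors, one per projection in the Distribute:

    import torch

    # Import path is an assumption; adjust to your checkout of refiners.
    from refiners.fluxion.layers import Attention

    attention = Attention(
        embedding_dim=64,
        num_heads=8,
        is_optimized=False,  # use sparse_dot_product_attention_non_optimized internally
    )

    q, k, v = (torch.randn(2, 10, 64) for _ in range(3))
    out = attention(q, k, v)  # Distribute projects each input, then attention, then the output Linear
    assert out.shape == (2, 10, 64)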
@@ -137,6 +162,7 @@ class SelfAttention(Attention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -146,6 +172,7 @@ class SelfAttention(Attention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
@@ -153,7 +180,7 @@ class SelfAttention(Attention):


 class SelfAttention2d(SelfAttention):
-    structural_attrs = ["channels"]
+    structural_attrs = Attention.structural_attrs + ["channels"]

     def __init__(
         self,
@@ -161,6 +188,7 @@ class SelfAttention2d(SelfAttention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -171,6 +199,7 @@ class SelfAttention2d(SelfAttention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
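SelfAttention2d now accepts and forwards the same flag, and its structural_attrs list extends the parent's entries rather than replacing them, so "is_optimized" is tracked there as well. A final usage sketch; the channels argument and the (batch, channels, height, width) input layout are not shown in this diff and are assumptions here:

    import torch

    # Import path, the channels parameter, and the 4d input layout are assumptions
    # based on the class name; adjust to your checkout of refiners.
    from refiners.fluxion.layers import SelfAttention2d

    attention = SelfAttention2d(channels=64, num_heads=8, is_optimized=False)

    x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
    y = attention(x)
    assert y.shape == x.shape  # self-attention over the flattened spatial positions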