Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-21 13:48:46 +00:00
add is_optimized option for attention
This commit is contained in:
parent fc2390ad1c
commit 121ef4df39
@@ -1,3 +1,5 @@
+import math
+import torch
 from jaxtyping import Float
 from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention  # type: ignore
 from torch import Tensor, device as Device, dtype as DType
@@ -18,11 +20,31 @@ def scaled_dot_product_attention(
     return _scaled_dot_product_attention(query, key, value, is_causal=is_causal)  # type: ignore
 
 
+def sparse_dot_product_attention_non_optimized(
+    query: Float[Tensor, "batch source_sequence_length dim"],
+    key: Float[Tensor, "batch target_sequence_length dim"],
+    value: Float[Tensor, "batch target_sequence_length dim"],
+    is_causal: bool = False,
+) -> Float[Tensor, "batch source_sequence_length dim"]:
+    if is_causal:
+        # TODO: implement causal attention
+        raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented")
+    _, _, _, dim = query.shape
+    attention = query @ key.permute(0, 1, 3, 2)
+    attention = attention / math.sqrt(dim)
+    attention = torch.softmax(input=attention, dim=-1)
+    return attention @ value
+
+
 class ScaledDotProductAttention(Module):
-    def __init__(self, num_heads: int = 1, is_causal: bool | None = None) -> None:
+    def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
+        self.dot_product = (
+            scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
+        )
 
     def forward(
         self,
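For context, the non-optimized path added above computes the textbook softmax(Q @ K^T / sqrt(dim)) @ V on (batch, heads, sequence, dim) tensors. The following is a minimal standalone sketch, not part of this commit, that mirrors that math and checks it against PyTorch's fused kernel in the non-causal case:

# Hypothetical sanity check (not part of the diff): the manual attention below
# reproduces sparse_dot_product_attention_non_optimized and should match torch's
# fused scaled_dot_product_attention when no causal mask is applied.
import math

import torch
from torch.nn.functional import scaled_dot_product_attention


def manual_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # softmax(Q @ K^T / sqrt(dim)) @ V on (batch, heads, sequence, dim) tensors
    _, _, _, dim = query.shape
    scores = (query @ key.permute(0, 1, 3, 2)) / math.sqrt(dim)
    return torch.softmax(scores, dim=-1) @ value


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
assert torch.allclose(manual_attention(q, k, v), scaled_dot_product_attention(q, k, v), atol=1e-5)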
@@ -32,7 +54,7 @@ class ScaledDotProductAttention(Module):
         is_causal: bool | None = None,
     ) -> Float[Tensor, "batch num_queries dim"]:
         return self.merge_multi_head(
-            scaled_dot_product_attention(
+            x=self.dot_product(
                 query=self.split_to_multi_head(query),
                 key=self.split_to_multi_head(key),
                 value=self.split_to_multi_head(value),
@@ -69,6 +91,7 @@ class Attention(Chain):
         "inner_dim",
         "use_bias",
         "is_causal",
+        "is_optimized",
     ]
 
     def __init__(
@@ -80,6 +103,7 @@ class Attention(Chain):
         inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -94,6 +118,7 @@ class Attention(Chain):
         self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
         super().__init__(
             Distribute(
                 Linear(
@@ -118,7 +143,7 @@ class Attention(Chain):
                     dtype=dtype,
                 ),
             ),
-            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
+            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized),
             Linear(
                 in_features=self.inner_dim,
                 out_features=self.embedding_dim,
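As a usage sketch of the new flag propagating from Attention down to the inner ScaledDotProductAttention; the import path and the call pattern are assumptions for illustration, not part of the diff:

# Hypothetical usage sketch; the import path is an assumption.
import torch
from refiners.fluxion.layers.attentions import Attention

fused = Attention(embedding_dim=64, num_heads=4)  # default: is_optimized=True, fused torch kernel
manual = Attention(embedding_dim=64, num_heads=4, is_optimized=False)  # manual softmax(QK^T/sqrt(d))V path

q = k = v = torch.randn(1, 10, 64)
out = manual(q, k, v)  # the Chain projects q/k/v, applies the selected dot-product, then the output Linear
print(out.shape)  # expected: torch.Size([1, 10, 64])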
@@ -137,6 +162,7 @@ class SelfAttention(Attention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -146,6 +172,7 @@ class SelfAttention(Attention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
@@ -153,7 +180,7 @@ class SelfAttention(Attention):
 
 
 class SelfAttention2d(SelfAttention):
-    structural_attrs = ["channels"]
+    structural_attrs = Attention.structural_attrs + ["channels"]
 
     def __init__(
         self,
@@ -161,6 +188,7 @@ class SelfAttention2d(SelfAttention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -171,6 +199,7 @@ class SelfAttention2d(SelfAttention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
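Finally, a small sketch (again assuming the import path and that the layer can be called directly with three tensors) exercising both code paths at the lowest level; in the non-causal case the two settings should agree numerically:

# Hypothetical check, not part of this commit; the import path is an assumption.
import torch
from refiners.fluxion.layers.attentions import ScaledDotProductAttention

q, k, v = (torch.randn(2, 10, 32) for _ in range(3))  # (batch, sequence, num_heads * head_dim)

optimized = ScaledDotProductAttention(num_heads=4)  # fused torch kernel
non_optimized = ScaledDotProductAttention(num_heads=4, is_optimized=False)  # manual path added here

assert torch.allclose(optimized(q, k, v), non_optimized(q, k, v), atol=1e-5)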