add is_optimized option for attention

Benjamin Trom 2023-09-14 11:50:33 +02:00
parent fc2390ad1c
commit 121ef4df39


@@ -1,3 +1,5 @@
+import math
+import torch
 from jaxtyping import Float
 from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention  # type: ignore
 from torch import Tensor, device as Device, dtype as DType
@@ -18,11 +20,31 @@ def scaled_dot_product_attention(
     return _scaled_dot_product_attention(query, key, value, is_causal=is_causal)  # type: ignore
 
 
+def sparse_dot_product_attention_non_optimized(
+    query: Float[Tensor, "batch source_sequence_length dim"],
+    key: Float[Tensor, "batch target_sequence_length dim"],
+    value: Float[Tensor, "batch target_sequence_length dim"],
+    is_causal: bool = False,
+) -> Float[Tensor, "batch source_sequence_length dim"]:
+    if is_causal:
+        # TODO: implement causal attention
+        raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented")
+    _, _, _, dim = query.shape
+    attention = query @ key.permute(0, 1, 3, 2)
+    attention = attention / math.sqrt(dim)
+    attention = torch.softmax(input=attention, dim=-1)
+    return attention @ value
+
+
 class ScaledDotProductAttention(Module):
-    def __init__(self, num_heads: int = 1, is_causal: bool | None = None) -> None:
+    def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
+        self.dot_product = (
+            scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
+        )
 
     def forward(
         self,
@@ -32,7 +54,7 @@ class ScaledDotProductAttention(Module):
         is_causal: bool | None = None,
     ) -> Float[Tensor, "batch num_queries dim"]:
         return self.merge_multi_head(
-            scaled_dot_product_attention(
+            x=self.dot_product(
                 query=self.split_to_multi_head(query),
                 key=self.split_to_multi_head(key),
                 value=self.split_to_multi_head(value),
@@ -69,6 +91,7 @@ class Attention(Chain):
         "inner_dim",
         "use_bias",
         "is_causal",
+        "is_optimized",
     ]
 
     def __init__(
@@ -80,6 +103,7 @@ class Attention(Chain):
         inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -94,6 +118,7 @@ class Attention(Chain):
         self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
         super().__init__(
             Distribute(
                 Linear(
@@ -118,7 +143,7 @@ class Attention(Chain):
                     dtype=dtype,
                 ),
             ),
-            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
+            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized),
             Linear(
                 in_features=self.inner_dim,
                 out_features=self.embedding_dim,
@@ -137,6 +162,7 @@ class SelfAttention(Attention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -146,6 +172,7 @@ class SelfAttention(Attention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
@@ -153,7 +180,7 @@ class SelfAttention(Attention):
 
 
 class SelfAttention2d(SelfAttention):
-    structural_attrs = ["channels"]
+    structural_attrs = Attention.structural_attrs + ["channels"]
 
     def __init__(
         self,
@@ -161,6 +188,7 @@ class SelfAttention2d(SelfAttention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -171,6 +199,7 @@ class SelfAttention2d(SelfAttention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
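
For reference, a minimal sketch of what the new flag switches between, using plain torch only; the tensor shapes and the tolerance below are illustrative assumptions, not taken from this commit. The non-optimized path is an explicit QK^T / sqrt(dim), softmax, then weighting of V, and in the non-causal case it should agree with PyTorch's fused scaled_dot_product_attention, which the default is_optimized=True keeps using.

import math

import torch
from torch.nn.functional import scaled_dot_product_attention

# Illustrative shapes: (batch, num_heads, sequence_length, head_dim),
# i.e. what the multi-head split produces before attention is applied.
query = torch.randn(2, 4, 8, 16)
key = torch.randn(2, 4, 8, 16)
value = torch.randn(2, 4, 8, 16)

# Non-optimized path: explicit QK^T / sqrt(dim), softmax, then weighting of V,
# mirroring sparse_dot_product_attention_non_optimized in the diff above.
scores = query @ key.permute(0, 1, 3, 2)
scores = scores / math.sqrt(query.shape[-1])
manual = torch.softmax(scores, dim=-1) @ value

# Optimized path: PyTorch's fused kernel (used when is_optimized=True, the default).
fused = scaled_dot_product_attention(query, key, value)

# The two paths should agree up to floating-point noise (non-causal case only;
# the manual path raises NotImplementedError for is_causal=True).
assert torch.allclose(manual, fused, atol=1e-5)

Passing is_optimized=False to ScaledDotProductAttention, or to Attention / SelfAttention / SelfAttention2d which forward the flag down, routes the forward pass through sparse_dot_product_attention_non_optimized instead of the fused kernel.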