diff --git a/src/refiners/fluxion/layers/attentions.py b/src/refiners/fluxion/layers/attentions.py
index 8676bff..3d356b0 100644
--- a/src/refiners/fluxion/layers/attentions.py
+++ b/src/refiners/fluxion/layers/attentions.py
@@ -1,3 +1,5 @@
+import math
+import torch
 from jaxtyping import Float
 from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention  # type: ignore
 from torch import Tensor, device as Device, dtype as DType
@@ -18,11 +20,31 @@ def scaled_dot_product_attention(
     return _scaled_dot_product_attention(query, key, value, is_causal=is_causal)  # type: ignore
 
 
+def scaled_dot_product_attention_non_optimized(
+    query: Float[Tensor, "batch source_sequence_length dim"],
+    key: Float[Tensor, "batch target_sequence_length dim"],
+    value: Float[Tensor, "batch target_sequence_length dim"],
+    is_causal: bool = False,
+) -> Float[Tensor, "batch source_sequence_length dim"]:
+    if is_causal:
+        # TODO: implement causal attention
+        raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented")
+    _, _, _, dim = query.shape
+    attention = query @ key.permute(0, 1, 3, 2)
+    attention = attention / math.sqrt(dim)
+    attention = torch.softmax(input=attention, dim=-1)
+    return attention @ value
+
+
 class ScaledDotProductAttention(Module):
-    def __init__(self, num_heads: int = 1, is_causal: bool | None = None) -> None:
+    def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
+        self.dot_product = (
+            scaled_dot_product_attention if self.is_optimized else scaled_dot_product_attention_non_optimized
+        )
 
     def forward(
         self,
@@ -32,7 +54,7 @@ class ScaledDotProductAttention(Module):
         is_causal: bool | None = None,
     ) -> Float[Tensor, "batch num_queries dim"]:
         return self.merge_multi_head(
-            scaled_dot_product_attention(
+            x=self.dot_product(
                 query=self.split_to_multi_head(query),
                 key=self.split_to_multi_head(key),
                 value=self.split_to_multi_head(value),
@@ -69,6 +91,7 @@ class Attention(Chain):
         "inner_dim",
         "use_bias",
         "is_causal",
+        "is_optimized",
     ]
 
     def __init__(
@@ -80,6 +103,7 @@ class Attention(Chain):
         inner_dim: int | None = None,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -94,6 +118,7 @@ class Attention(Chain):
         self.inner_dim = inner_dim or embedding_dim
         self.use_bias = use_bias
         self.is_causal = is_causal
+        self.is_optimized = is_optimized
         super().__init__(
             Distribute(
                 Linear(
@@ -118,7 +143,7 @@ class Attention(Chain):
                     dtype=dtype,
                 ),
             ),
-            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
+            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized),
             Linear(
                 in_features=self.inner_dim,
                 out_features=self.embedding_dim,
@@ -137,6 +162,7 @@ class SelfAttention(Attention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -146,6 +172,7 @@ class SelfAttention(Attention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
@@ -153,7 +180,7 @@ class SelfAttention(Attention):
 
 
 class SelfAttention2d(SelfAttention):
-    structural_attrs = ["channels"]
+    structural_attrs = Attention.structural_attrs + ["channels"]
 
     def __init__(
         self,
@@ -161,6 +188,7 @@ class SelfAttention2d(SelfAttention):
         num_heads: int = 1,
         use_bias: bool = True,
         is_causal: bool | None = None,
+        is_optimized: bool = True,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -171,6 +199,7 @@ class SelfAttention2d(SelfAttention):
             num_heads=num_heads,
             use_bias=use_bias,
             is_causal=is_causal,
+            is_optimized=is_optimized,
             device=device,
             dtype=dtype,
         )
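
Quick sanity check for the two code paths above: the snippet below compares the wrapper around torch's scaled_dot_product_attention with the new non-optimized fallback on the same multi-head tensors. This is a minimal sketch, assuming the package is installed so that refiners.fluxion.layers.attentions is importable; the tensor shapes and tolerance are illustrative only, and the same behavior can be selected end to end by constructing SelfAttention(..., is_optimized=False).

import torch

from refiners.fluxion.layers.attentions import (
    scaled_dot_product_attention,
    scaled_dot_product_attention_non_optimized,
)

# Shapes follow the multi-head layout produced by split_to_multi_head:
# (batch, num_heads, sequence_length, head_dim).
query = torch.randn(2, 4, 16, 32)
key = torch.randn(2, 4, 16, 32)
value = torch.randn(2, 4, 16, 32)

optimized = scaled_dot_product_attention(query, key, value)
non_optimized = scaled_dot_product_attention_non_optimized(query, key, value)

# Both paths compute softmax(Q K^T / sqrt(head_dim)) V, so the results
# should agree up to floating-point error.
assert torch.allclose(optimized, non_optimized, atol=1e-5)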