(doc/fluxion/attention) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 22:05:54 +00:00 committed by Laureηt
parent e3238a6af5
commit 0fc3264fae

View file

@ -3,7 +3,7 @@ import math
import torch import torch
from jaxtyping import Float from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention
from refiners.fluxion.context import Contexts from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.basics import Identity from refiners.fluxion.layers.basics import Identity
@ -18,19 +18,37 @@ def scaled_dot_product_attention(
value: Float[Tensor, "batch target_sequence_length dim"], value: Float[Tensor, "batch target_sequence_length dim"],
is_causal: bool = False, is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]: ) -> Float[Tensor, "batch source_sequence_length dim"]:
return _scaled_dot_product_attention(query, key, value, is_causal=is_causal) # type: ignore """Scaled Dot Product Attention.
Optimization depends on which pytorch backend is used.
See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
See also [torch.nn.functional.scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""
return _scaled_dot_product_attention(
query=query,
key=key,
value=value,
is_causal=is_causal,
)
def sparse_dot_product_attention_non_optimized( def scaled_dot_product_attention_non_optimized(
query: Float[Tensor, "batch source_sequence_length dim"], query: Float[Tensor, "batch source_sequence_length dim"],
key: Float[Tensor, "batch target_sequence_length dim"], key: Float[Tensor, "batch target_sequence_length dim"],
value: Float[Tensor, "batch target_sequence_length dim"], value: Float[Tensor, "batch target_sequence_length dim"],
is_causal: bool = False, is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]: ) -> Float[Tensor, "batch source_sequence_length dim"]:
"""Non-optimized Scaled Dot Product Attention.
See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
"""
if is_causal: if is_causal:
# TODO: implement causal attention # TODO: implement causal attention
raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented") raise NotImplementedError(
_, _, _, dim = query.shape "Causal attention for `scaled_dot_product_attention_non_optimized` is not yet implemented"
)
dim = query.shape[-1]
attention = query @ key.permute(0, 1, 3, 2) attention = query @ key.permute(0, 1, 3, 2)
attention = attention / math.sqrt(dim) attention = attention / math.sqrt(dim)
attention = torch.softmax(input=attention, dim=-1) attention = torch.softmax(input=attention, dim=-1)
@ -38,20 +56,58 @@ def sparse_dot_product_attention_non_optimized(
class ScaledDotProductAttention(Module): class ScaledDotProductAttention(Module):
"""Scaled Dot Product Attention.
??? note "See [[arXiv:1706.03762] Attention Is All You Need (Figure 2)](https://arxiv.org/abs/1706.03762) for more details"
![](https://ar5iv.labs.arxiv.org/html/1706.03762/assets/Figures/ModalNet-19.png)
Note:
This layer simply wraps `scaled_dot_product_attention` inside an `fl.Module`.
Receives:
Query (Float[Tensor, "batch num_queries embedding_dim"]):
Key (Float[Tensor, "batch num_keys embedding_dim"]):
Value (Float[Tensor, "batch num_values embedding_dim"]):
Returns:
(Float[Tensor, "batch num_queries embedding_dim"]):
Example:
```py
attention = fl.ScaledDotProductAttention(num_heads=8)
query = torch.randn(2, 10, 128)
key = torch.randn(2, 10, 128)
value = torch.randn(2, 10, 128)
output = attention(query, key, value)
assert output.shape == (2, 10, 128)
```
"""
def __init__( def __init__(
self, self,
num_heads: int = 1, num_heads: int = 1,
is_causal: bool | None = None, is_causal: bool = False,
is_optimized: bool = True, is_optimized: bool = True,
slice_size: int | None = None, slice_size: int | None = None,
) -> None: ) -> None:
"""Initialize the Scaled Dot Product Attention layer.
Args:
num_heads: The number of heads to use.
is_causal: Whether to use causal attention.
is_optimized: Whether to use optimized attention.
slice_size: The slice size to use for the optimized attention.
"""
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.is_causal = is_causal self.is_causal = is_causal
self.is_optimized = is_optimized self.is_optimized = is_optimized
self.slice_size = slice_size self.slice_size = slice_size
self.dot_product = ( self.dot_product = (
scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized scaled_dot_product_attention if self.is_optimized else scaled_dot_product_attention_non_optimized
) )
def forward( def forward(
@ -59,12 +115,20 @@ class ScaledDotProductAttention(Module):
query: Float[Tensor, "batch num_queries embedding_dim"], query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"], key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"], value: Float[Tensor, "batch num_values embedding_dim"],
is_causal: bool | None = None, ) -> Float[Tensor, "batch num_queries embedding_dim"]:
) -> Float[Tensor, "batch num_queries dim"]: if self.slice_size:
if self.slice_size is None: return self._sliced_attention(
return self._process_attention(query, key, value, is_causal) query=query,
key=key,
return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size) value=value,
slice_size=self.slice_size,
)
else:
return self._process_attention(
query=query,
key=key,
value=value,
)
def _sliced_attention( def _sliced_attention(
self, self,
@ -72,14 +136,19 @@ class ScaledDotProductAttention(Module):
key: Float[Tensor, "batch num_keys embedding_dim"], key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"], value: Float[Tensor, "batch num_values embedding_dim"],
slice_size: int, slice_size: int,
is_causal: bool | None = None, ) -> Float[Tensor, "batch num_queries embedding_dim"]:
) -> Float[Tensor, "batch num_queries dim"]: """Compute the scaled dot product attention in slices.
This is useful when the input tensors are too large to be processed in one go.
"""
_, num_queries, _ = query.shape _, num_queries, _ = query.shape
output = torch.zeros_like(query) output = torch.zeros_like(query)
for start_idx in range(0, num_queries, slice_size): for start_idx in range(0, num_queries, slice_size):
end_idx = min(start_idx + slice_size, num_queries) end_idx = min(start_idx + slice_size, num_queries)
output[:, start_idx:end_idx, :] = self._process_attention( output[:, start_idx:end_idx, :] = self._process_attention(
query[:, start_idx:end_idx, :], key, value, is_causal query=query[:, start_idx:end_idx, :],
key=key,
value=value,
) )
return output return output
@ -88,37 +157,84 @@ class ScaledDotProductAttention(Module):
query: Float[Tensor, "batch num_queries embedding_dim"], query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"], key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"], value: Float[Tensor, "batch num_values embedding_dim"],
is_causal: bool | None = None, ) -> Float[Tensor, "batch num_queries embedding_dim"]:
) -> Float[Tensor, "batch num_queries dim"]: """Compute the scaled dot product attention.
return self.merge_multi_head(
Split the input tensors (query, key, value) into multiple heads along the embedding dimension,
then compute the scaled dot product attention for each head, and finally merge the heads back.
"""
return self._merge_multi_head(
x=self.dot_product( x=self.dot_product(
query=self.split_to_multi_head(query), query=self._split_to_multi_head(query),
key=self.split_to_multi_head(key), key=self._split_to_multi_head(key),
value=self.split_to_multi_head(value), value=self._split_to_multi_head(value),
is_causal=( is_causal=self.is_causal,
is_causal if is_causal is not None else (self.is_causal if self.is_causal is not None else False)
),
) )
) )
def split_to_multi_head( def _split_to_multi_head(
self, x: Float[Tensor, "batch_size sequence_length embedding_dim"] self,
x: Float[Tensor, "batch_size sequence_length embedding_dim"],
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]: ) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
"""Split the input tensor into multiple heads along the embedding dimension.
See also `merge_multi_head`, which is the inverse operation.
"""
assert ( assert (
len(x.shape) == 3 x.ndim == 3
), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}" ), f"Expected input tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
assert ( assert (
x.shape[-1] % self.num_heads == 0 x.shape[-1] % self.num_heads == 0
), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads" ), f"Expected embedding_dim (x.shape[-1]={x.shape[-1]}) to be divisible by num_heads ({self.num_heads})"
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2) return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
def merge_multi_head( def _merge_multi_head(
self, x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"] self,
x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"],
) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]: ) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
"""Merge the input tensor from multiple heads along the embedding dimension.
See also `split_to_multi_head`, which is the inverse operation.
"""
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1]) return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])
class Attention(Chain): class Attention(Chain):
"""Multi-Head Attention layer.
??? note "See [[arXiv:1706.03762] Attention Is All You Need (Figure 2)](https://arxiv.org/abs/1706.03762) for more details"
![](https://ar5iv.labs.arxiv.org/html/1706.03762/assets/Figures/ModalNet-20.png)
Note: This layer simply chains
- a [`Distribute`][refiners.fluxion.layers.chain.Distribute] layer,
containing 3 [`Linear`][refiners.fluxion.layers.linear.Linear] layers,
which transforms the 3 inputs into Query, Key and Value
- a [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
- a [`Linear`][refiners.fluxion.layers.linear.Linear] layer,
which further transforms the output of the
[`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
Receives:
Query (Float[Tensor, "batch sequence_length embedding_dim"]):
Key (Float[Tensor, "batch sequence_length embedding_dim"]):
Value (Float[Tensor, "batch sequence_length embedding_dim"]):
Returns:
(Float[Tensor, "batch sequence_length embedding_dim"]):
Example:
```py
attention = fl.Attention(num_heads=8, embedding_dim=128)
tensor = torch.randn(2, 10, 128)
output = attention(tensor, tensor, tensor)
assert output.shape == (2, 10, 128)
```
"""
def __init__( def __init__(
self, self,
embedding_dim: int, embedding_dim: int,
@ -127,11 +243,25 @@ class Attention(Chain):
value_embedding_dim: int | None = None, value_embedding_dim: int | None = None,
inner_dim: int | None = None, inner_dim: int | None = None,
use_bias: bool = True, use_bias: bool = True,
is_causal: bool | None = None, is_causal: bool = False,
is_optimized: bool = True, is_optimized: bool = True,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
"""Initialize the Attention layer.
Args:
embedding_dim: The embedding dimension of the input and output tensors.
num_heads: The number of heads to use.
key_embedding_dim: The embedding dimension of the key tensor.
value_embedding_dim: The embedding dimension of the value tensor.
inner_dim: The inner dimension of the linear layers.
use_bias: Whether to use bias in the linear layers.
is_causal: Whether to use causal attention.
is_optimized: Whether to use optimized attention.
device: The device to use.
dtype: The dtype to use.
"""
assert ( assert (
embedding_dim % num_heads == 0 embedding_dim % num_heads == 0
), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}" ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
@ -144,23 +274,24 @@ class Attention(Chain):
self.use_bias = use_bias self.use_bias = use_bias
self.is_causal = is_causal self.is_causal = is_causal
self.is_optimized = is_optimized self.is_optimized = is_optimized
super().__init__( super().__init__(
Distribute( Distribute(
Linear( Linear( # Query projection
in_features=self.embedding_dim, in_features=self.embedding_dim,
out_features=self.inner_dim, out_features=self.inner_dim,
bias=self.use_bias, bias=self.use_bias,
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
Linear( Linear( # Key projection
in_features=self.key_embedding_dim, in_features=self.key_embedding_dim,
out_features=self.inner_dim, out_features=self.inner_dim,
bias=self.use_bias, bias=self.use_bias,
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
Linear( Linear( # Value projection
in_features=self.value_embedding_dim, in_features=self.value_embedding_dim,
out_features=self.inner_dim, out_features=self.inner_dim,
bias=self.use_bias, bias=self.use_bias,
@ -168,8 +299,12 @@ class Attention(Chain):
dtype=dtype, dtype=dtype,
), ),
), ),
ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized), ScaledDotProductAttention(
Linear( num_heads=num_heads,
is_causal=is_causal,
is_optimized=is_optimized,
),
Linear( # Output projection
in_features=self.inner_dim, in_features=self.inner_dim,
out_features=self.embedding_dim, out_features=self.embedding_dim,
bias=True, bias=True,
@ -180,17 +315,54 @@ class Attention(Chain):
class SelfAttention(Attention): class SelfAttention(Attention):
"""Multi-Head Self-Attention layer.
Note: This layer simply chains
- a [`Parallel`][refiners.fluxion.layers.chain.Parallel] layer,
which duplicates the input Tensor
(for each Linear layer in the `Attention` layer)
- an [`Attention`][refiners.fluxion.layers.attentions.Attention] layer
Receives:
(Float[Tensor, "batch sequence_length embedding_dim"]):
Returns:
(Float[Tensor, "batch sequence_length embedding_dim"]):
Example:
```py
self_attention = fl.SelfAttention(num_heads=8, embedding_dim=128)
tensor = torch.randn(2, 10, 128)
output = self_attention(tensor)
assert output.shape == (2, 10, 128)
```
"""
def __init__( def __init__(
self, self,
embedding_dim: int, embedding_dim: int,
inner_dim: int | None = None, inner_dim: int | None = None,
num_heads: int = 1, num_heads: int = 1,
use_bias: bool = True, use_bias: bool = True,
is_causal: bool | None = None, is_causal: bool = False,
is_optimized: bool = True, is_optimized: bool = True,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
"""Initialize the Self-Attention layer.
Args:
embedding_dim: The embedding dimension of the input and output tensors.
inner_dim: The inner dimension of the linear layers.
num_heads: The number of heads to use.
use_bias: Whether to use bias in the linear layers.
is_causal: Whether to use causal attention.
is_optimized: Whether to use optimized attention.
device: The device to use.
dtype: The dtype to use.
"""
super().__init__( super().__init__(
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
inner_dim=inner_dim, inner_dim=inner_dim,
@ -201,22 +373,67 @@ class SelfAttention(Attention):
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self.insert(0, Parallel(Identity(), Identity(), Identity())) self.insert(
index=0,
module=Parallel(
Identity(), # Query projection's input
Identity(), # Key projection's input
Identity(), # Value projection's input
),
)
class SelfAttention2d(SelfAttention): class SelfAttention2d(SelfAttention):
"""Multi-Head 2D Self-Attention layer.
Note: This Module simply chains
- a [`Lambda`][refiners.fluxion.layers.chain.Lambda] layer,
which transforms the input Tensor into a sequence
- a [`SelfAttention`][refiners.fluxion.layers.attentions.SelfAttention] layer
- a [`Lambda`][refiners.fluxion.layers.chain.Lambda] layer,
which transforms the output sequence into a 2D Tensor
Receives:
(Float[Tensor, "batch channels height width"]):
Returns:
(Float[Tensor, "batch channels height width"]):
Example:
```py
self_attention = fl.SelfAttention2d(channels=128, num_heads=8)
tensor = torch.randn(2, 128, 64, 64)
output = self_attention(tensor)
assert output.shape == (2, 128, 64, 64)
```
"""
def __init__( def __init__(
self, self,
channels: int, channels: int,
num_heads: int = 1, num_heads: int = 1,
use_bias: bool = True, use_bias: bool = True,
is_causal: bool | None = None, is_causal: bool = False,
is_optimized: bool = True, is_optimized: bool = True,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
"""Initialize the 2D Self-Attention layer.
Args:
channels: The number of channels of the input and output tensors.
num_heads: The number of heads to use.
use_bias: Whether to use bias in the linear layers.
is_causal: Whether to use causal attention.
is_optimized: Whether to use optimized attention.
device: The device to use.
dtype: The dtype to use.
"""
assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}" assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
self.channels = channels self.channels = channels
super().__init__( super().__init__(
embedding_dim=channels, embedding_dim=channels,
num_heads=num_heads, num_heads=num_heads,
@ -226,21 +443,45 @@ class SelfAttention2d(SelfAttention):
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self.insert(0, Lambda(self.tensor_2d_to_sequence))
self.append(Lambda(self.sequence_to_tensor_2d)) self.insert(0, Lambda(self._tensor_2d_to_sequence))
self.append(Lambda(self._sequence_to_tensor_2d))
def init_context(self) -> Contexts: def init_context(self) -> Contexts:
return {"reshape": {"height": None, "width": None}} return {
"reshape": {
"height": None,
"width": None,
}
}
def tensor_2d_to_sequence( def _tensor_2d_to_sequence(
self, x: Float[Tensor, "batch channels height width"] self,
x: Float[Tensor, "batch channels height width"],
) -> Float[Tensor, "batch height*width channels"]: ) -> Float[Tensor, "batch height*width channels"]:
"""Transform a 2D Tensor into a sequence.
The height and width of the input Tensor are stored in the context,
so that the output Tensor can be transformed back into a 2D Tensor in the `sequence_to_tensor_2d` method.
"""
height, width = x.shape[-2:] height, width = x.shape[-2:]
self.set_context(context="reshape", value={"height": height, "width": width}) self.set_context(
context="reshape",
value={
"height": height,
"width": width,
},
)
return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2) return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)
def sequence_to_tensor_2d( def _sequence_to_tensor_2d(
self, x: Float[Tensor, "batch sequence_length channels"] self,
x: Float[Tensor, "batch sequence_length channels"],
) -> Float[Tensor, "batch channels height width"]: ) -> Float[Tensor, "batch channels height width"]:
"""Transform a sequence into a 2D Tensor.
The height and width of the output Tensor are retrieved from the context,
which was set in the `tensor_2d_to_sequence` method.
"""
height, width = self.use_context("reshape").values() height, width = self.use_context("reshape").values()
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width) return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)