mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(doc/fluxion/attention) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
e3238a6af5
commit
0fc3264fae
|
@ -3,7 +3,7 @@ import math
|
||||||
import torch
|
import torch
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore
|
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention
|
||||||
|
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
from refiners.fluxion.layers.basics import Identity
|
from refiners.fluxion.layers.basics import Identity
|
||||||
|
@ -18,19 +18,37 @@ def scaled_dot_product_attention(
|
||||||
value: Float[Tensor, "batch target_sequence_length dim"],
|
value: Float[Tensor, "batch target_sequence_length dim"],
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
) -> Float[Tensor, "batch source_sequence_length dim"]:
|
) -> Float[Tensor, "batch source_sequence_length dim"]:
|
||||||
return _scaled_dot_product_attention(query, key, value, is_causal=is_causal) # type: ignore
|
"""Scaled Dot Product Attention.
|
||||||
|
|
||||||
|
Optimization depends on which pytorch backend is used.
|
||||||
|
See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
|
||||||
|
See also [torch.nn.functional.scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
|
||||||
|
"""
|
||||||
|
return _scaled_dot_product_attention(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
is_causal=is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sparse_dot_product_attention_non_optimized(
|
def scaled_dot_product_attention_non_optimized(
|
||||||
query: Float[Tensor, "batch source_sequence_length dim"],
|
query: Float[Tensor, "batch source_sequence_length dim"],
|
||||||
key: Float[Tensor, "batch target_sequence_length dim"],
|
key: Float[Tensor, "batch target_sequence_length dim"],
|
||||||
value: Float[Tensor, "batch target_sequence_length dim"],
|
value: Float[Tensor, "batch target_sequence_length dim"],
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
) -> Float[Tensor, "batch source_sequence_length dim"]:
|
) -> Float[Tensor, "batch source_sequence_length dim"]:
|
||||||
|
"""Non-optimized Scaled Dot Product Attention.
|
||||||
|
|
||||||
|
See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
|
||||||
|
"""
|
||||||
if is_causal:
|
if is_causal:
|
||||||
# TODO: implement causal attention
|
# TODO: implement causal attention
|
||||||
raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented")
|
raise NotImplementedError(
|
||||||
_, _, _, dim = query.shape
|
"Causal attention for `scaled_dot_product_attention_non_optimized` is not yet implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
dim = query.shape[-1]
|
||||||
attention = query @ key.permute(0, 1, 3, 2)
|
attention = query @ key.permute(0, 1, 3, 2)
|
||||||
attention = attention / math.sqrt(dim)
|
attention = attention / math.sqrt(dim)
|
||||||
attention = torch.softmax(input=attention, dim=-1)
|
attention = torch.softmax(input=attention, dim=-1)
|
||||||
|
@ -38,20 +56,58 @@ def sparse_dot_product_attention_non_optimized(
|
||||||
|
|
||||||
|
|
||||||
class ScaledDotProductAttention(Module):
|
class ScaledDotProductAttention(Module):
|
||||||
|
"""Scaled Dot Product Attention.
|
||||||
|
|
||||||
|
??? note "See [[arXiv:1706.03762] Attention Is All You Need (Figure 2)](https://arxiv.org/abs/1706.03762) for more details"
|
||||||
|
|
||||||
|
![](https://ar5iv.labs.arxiv.org/html/1706.03762/assets/Figures/ModalNet-19.png)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This layer simply wraps `scaled_dot_product_attention` inside an `fl.Module`.
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
Query (Float[Tensor, "batch num_queries embedding_dim"]):
|
||||||
|
Key (Float[Tensor, "batch num_keys embedding_dim"]):
|
||||||
|
Value (Float[Tensor, "batch num_values embedding_dim"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch num_queries embedding_dim"]):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
attention = fl.ScaledDotProductAttention(num_heads=8)
|
||||||
|
|
||||||
|
query = torch.randn(2, 10, 128)
|
||||||
|
key = torch.randn(2, 10, 128)
|
||||||
|
value = torch.randn(2, 10, 128)
|
||||||
|
output = attention(query, key, value)
|
||||||
|
|
||||||
|
assert output.shape == (2, 10, 128)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_heads: int = 1,
|
num_heads: int = 1,
|
||||||
is_causal: bool | None = None,
|
is_causal: bool = False,
|
||||||
is_optimized: bool = True,
|
is_optimized: bool = True,
|
||||||
slice_size: int | None = None,
|
slice_size: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the Scaled Dot Product Attention layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_heads: The number of heads to use.
|
||||||
|
is_causal: Whether to use causal attention.
|
||||||
|
is_optimized: Whether to use optimized attention.
|
||||||
|
slice_size: The slice size to use for the optimized attention.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.is_causal = is_causal
|
self.is_causal = is_causal
|
||||||
self.is_optimized = is_optimized
|
self.is_optimized = is_optimized
|
||||||
self.slice_size = slice_size
|
self.slice_size = slice_size
|
||||||
self.dot_product = (
|
self.dot_product = (
|
||||||
scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
|
scaled_dot_product_attention if self.is_optimized else scaled_dot_product_attention_non_optimized
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -59,12 +115,20 @@ class ScaledDotProductAttention(Module):
|
||||||
query: Float[Tensor, "batch num_queries embedding_dim"],
|
query: Float[Tensor, "batch num_queries embedding_dim"],
|
||||||
key: Float[Tensor, "batch num_keys embedding_dim"],
|
key: Float[Tensor, "batch num_keys embedding_dim"],
|
||||||
value: Float[Tensor, "batch num_values embedding_dim"],
|
value: Float[Tensor, "batch num_values embedding_dim"],
|
||||||
is_causal: bool | None = None,
|
) -> Float[Tensor, "batch num_queries embedding_dim"]:
|
||||||
) -> Float[Tensor, "batch num_queries dim"]:
|
if self.slice_size:
|
||||||
if self.slice_size is None:
|
return self._sliced_attention(
|
||||||
return self._process_attention(query, key, value, is_causal)
|
query=query,
|
||||||
|
key=key,
|
||||||
return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size)
|
value=value,
|
||||||
|
slice_size=self.slice_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._process_attention(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
)
|
||||||
|
|
||||||
def _sliced_attention(
|
def _sliced_attention(
|
||||||
self,
|
self,
|
||||||
|
@ -72,14 +136,19 @@ class ScaledDotProductAttention(Module):
|
||||||
key: Float[Tensor, "batch num_keys embedding_dim"],
|
key: Float[Tensor, "batch num_keys embedding_dim"],
|
||||||
value: Float[Tensor, "batch num_values embedding_dim"],
|
value: Float[Tensor, "batch num_values embedding_dim"],
|
||||||
slice_size: int,
|
slice_size: int,
|
||||||
is_causal: bool | None = None,
|
) -> Float[Tensor, "batch num_queries embedding_dim"]:
|
||||||
) -> Float[Tensor, "batch num_queries dim"]:
|
"""Compute the scaled dot product attention in slices.
|
||||||
|
|
||||||
|
This is useful when the input tensors are too large to be processed in one go.
|
||||||
|
"""
|
||||||
_, num_queries, _ = query.shape
|
_, num_queries, _ = query.shape
|
||||||
output = torch.zeros_like(query)
|
output = torch.zeros_like(query)
|
||||||
for start_idx in range(0, num_queries, slice_size):
|
for start_idx in range(0, num_queries, slice_size):
|
||||||
end_idx = min(start_idx + slice_size, num_queries)
|
end_idx = min(start_idx + slice_size, num_queries)
|
||||||
output[:, start_idx:end_idx, :] = self._process_attention(
|
output[:, start_idx:end_idx, :] = self._process_attention(
|
||||||
query[:, start_idx:end_idx, :], key, value, is_causal
|
query=query[:, start_idx:end_idx, :],
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -88,37 +157,84 @@ class ScaledDotProductAttention(Module):
|
||||||
query: Float[Tensor, "batch num_queries embedding_dim"],
|
query: Float[Tensor, "batch num_queries embedding_dim"],
|
||||||
key: Float[Tensor, "batch num_keys embedding_dim"],
|
key: Float[Tensor, "batch num_keys embedding_dim"],
|
||||||
value: Float[Tensor, "batch num_values embedding_dim"],
|
value: Float[Tensor, "batch num_values embedding_dim"],
|
||||||
is_causal: bool | None = None,
|
) -> Float[Tensor, "batch num_queries embedding_dim"]:
|
||||||
) -> Float[Tensor, "batch num_queries dim"]:
|
"""Compute the scaled dot product attention.
|
||||||
return self.merge_multi_head(
|
|
||||||
|
Split the input tensors (query, key, value) into multiple heads along the embedding dimension,
|
||||||
|
then compute the scaled dot product attention for each head, and finally merge the heads back.
|
||||||
|
"""
|
||||||
|
return self._merge_multi_head(
|
||||||
x=self.dot_product(
|
x=self.dot_product(
|
||||||
query=self.split_to_multi_head(query),
|
query=self._split_to_multi_head(query),
|
||||||
key=self.split_to_multi_head(key),
|
key=self._split_to_multi_head(key),
|
||||||
value=self.split_to_multi_head(value),
|
value=self._split_to_multi_head(value),
|
||||||
is_causal=(
|
is_causal=self.is_causal,
|
||||||
is_causal if is_causal is not None else (self.is_causal if self.is_causal is not None else False)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def split_to_multi_head(
|
def _split_to_multi_head(
|
||||||
self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
|
self,
|
||||||
|
x: Float[Tensor, "batch_size sequence_length embedding_dim"],
|
||||||
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
|
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
|
||||||
|
"""Split the input tensor into multiple heads along the embedding dimension.
|
||||||
|
|
||||||
|
See also `merge_multi_head`, which is the inverse operation.
|
||||||
|
"""
|
||||||
assert (
|
assert (
|
||||||
len(x.shape) == 3
|
x.ndim == 3
|
||||||
), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
|
), f"Expected input tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
|
||||||
assert (
|
assert (
|
||||||
x.shape[-1] % self.num_heads == 0
|
x.shape[-1] % self.num_heads == 0
|
||||||
), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
|
), f"Expected embedding_dim (x.shape[-1]={x.shape[-1]}) to be divisible by num_heads ({self.num_heads})"
|
||||||
|
|
||||||
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
|
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
|
||||||
|
|
||||||
def merge_multi_head(
|
def _merge_multi_head(
|
||||||
self, x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"]
|
self,
|
||||||
|
x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"],
|
||||||
) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
|
) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
|
||||||
|
"""Merge the input tensor from multiple heads along the embedding dimension.
|
||||||
|
|
||||||
|
See also `split_to_multi_head`, which is the inverse operation.
|
||||||
|
"""
|
||||||
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])
|
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])
|
||||||
|
|
||||||
|
|
||||||
class Attention(Chain):
|
class Attention(Chain):
|
||||||
|
"""Multi-Head Attention layer.
|
||||||
|
|
||||||
|
??? note "See [[arXiv:1706.03762] Attention Is All You Need (Figure 2)](https://arxiv.org/abs/1706.03762) for more details"
|
||||||
|
|
||||||
|
![](https://ar5iv.labs.arxiv.org/html/1706.03762/assets/Figures/ModalNet-20.png)
|
||||||
|
|
||||||
|
Note: This layer simply chains
|
||||||
|
- a [`Distribute`][refiners.fluxion.layers.chain.Distribute] layer,
|
||||||
|
containing 3 [`Linear`][refiners.fluxion.layers.linear.Linear] layers,
|
||||||
|
which transforms the 3 inputs into Query, Key and Value
|
||||||
|
- a [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
|
||||||
|
- a [`Linear`][refiners.fluxion.layers.linear.Linear] layer,
|
||||||
|
which further transforms the output of the
|
||||||
|
[`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
Query (Float[Tensor, "batch sequence_length embedding_dim"]):
|
||||||
|
Key (Float[Tensor, "batch sequence_length embedding_dim"]):
|
||||||
|
Value (Float[Tensor, "batch sequence_length embedding_dim"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch sequence_length embedding_dim"]):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
attention = fl.Attention(num_heads=8, embedding_dim=128)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 10, 128)
|
||||||
|
output = attention(tensor, tensor, tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 10, 128)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
|
@ -127,11 +243,25 @@ class Attention(Chain):
|
||||||
value_embedding_dim: int | None = None,
|
value_embedding_dim: int | None = None,
|
||||||
inner_dim: int | None = None,
|
inner_dim: int | None = None,
|
||||||
use_bias: bool = True,
|
use_bias: bool = True,
|
||||||
is_causal: bool | None = None,
|
is_causal: bool = False,
|
||||||
is_optimized: bool = True,
|
is_optimized: bool = True,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the Attention layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_dim: The embedding dimension of the input and output tensors.
|
||||||
|
num_heads: The number of heads to use.
|
||||||
|
key_embedding_dim: The embedding dimension of the key tensor.
|
||||||
|
value_embedding_dim: The embedding dimension of the value tensor.
|
||||||
|
inner_dim: The inner dimension of the linear layers.
|
||||||
|
use_bias: Whether to use bias in the linear layers.
|
||||||
|
is_causal: Whether to use causal attention.
|
||||||
|
is_optimized: Whether to use optimized attention.
|
||||||
|
device: The device to use.
|
||||||
|
dtype: The dtype to use.
|
||||||
|
"""
|
||||||
assert (
|
assert (
|
||||||
embedding_dim % num_heads == 0
|
embedding_dim % num_heads == 0
|
||||||
), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
|
), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
|
||||||
|
@ -144,23 +274,24 @@ class Attention(Chain):
|
||||||
self.use_bias = use_bias
|
self.use_bias = use_bias
|
||||||
self.is_causal = is_causal
|
self.is_causal = is_causal
|
||||||
self.is_optimized = is_optimized
|
self.is_optimized = is_optimized
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
Distribute(
|
Distribute(
|
||||||
Linear(
|
Linear( # Query projection
|
||||||
in_features=self.embedding_dim,
|
in_features=self.embedding_dim,
|
||||||
out_features=self.inner_dim,
|
out_features=self.inner_dim,
|
||||||
bias=self.use_bias,
|
bias=self.use_bias,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
Linear(
|
Linear( # Key projection
|
||||||
in_features=self.key_embedding_dim,
|
in_features=self.key_embedding_dim,
|
||||||
out_features=self.inner_dim,
|
out_features=self.inner_dim,
|
||||||
bias=self.use_bias,
|
bias=self.use_bias,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
Linear(
|
Linear( # Value projection
|
||||||
in_features=self.value_embedding_dim,
|
in_features=self.value_embedding_dim,
|
||||||
out_features=self.inner_dim,
|
out_features=self.inner_dim,
|
||||||
bias=self.use_bias,
|
bias=self.use_bias,
|
||||||
|
@ -168,8 +299,12 @@ class Attention(Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized),
|
ScaledDotProductAttention(
|
||||||
Linear(
|
num_heads=num_heads,
|
||||||
|
is_causal=is_causal,
|
||||||
|
is_optimized=is_optimized,
|
||||||
|
),
|
||||||
|
Linear( # Output projection
|
||||||
in_features=self.inner_dim,
|
in_features=self.inner_dim,
|
||||||
out_features=self.embedding_dim,
|
out_features=self.embedding_dim,
|
||||||
bias=True,
|
bias=True,
|
||||||
|
@ -180,17 +315,54 @@ class Attention(Chain):
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(Attention):
|
class SelfAttention(Attention):
|
||||||
|
"""Multi-Head Self-Attention layer.
|
||||||
|
|
||||||
|
Note: This layer simply chains
|
||||||
|
- a [`Parallel`][refiners.fluxion.layers.chain.Parallel] layer,
|
||||||
|
which duplicates the input Tensor
|
||||||
|
(for each Linear layer in the `Attention` layer)
|
||||||
|
- an [`Attention`][refiners.fluxion.layers.attentions.Attention] layer
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Float[Tensor, "batch sequence_length embedding_dim"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch sequence_length embedding_dim"]):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
self_attention = fl.SelfAttention(num_heads=8, embedding_dim=128)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 10, 128)
|
||||||
|
output = self_attention(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 10, 128)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
inner_dim: int | None = None,
|
inner_dim: int | None = None,
|
||||||
num_heads: int = 1,
|
num_heads: int = 1,
|
||||||
use_bias: bool = True,
|
use_bias: bool = True,
|
||||||
is_causal: bool | None = None,
|
is_causal: bool = False,
|
||||||
is_optimized: bool = True,
|
is_optimized: bool = True,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the Self-Attention layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_dim: The embedding dimension of the input and output tensors.
|
||||||
|
inner_dim: The inner dimension of the linear layers.
|
||||||
|
num_heads: The number of heads to use.
|
||||||
|
use_bias: Whether to use bias in the linear layers.
|
||||||
|
is_causal: Whether to use causal attention.
|
||||||
|
is_optimized: Whether to use optimized attention.
|
||||||
|
device: The device to use.
|
||||||
|
dtype: The dtype to use.
|
||||||
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
inner_dim=inner_dim,
|
inner_dim=inner_dim,
|
||||||
|
@ -201,22 +373,67 @@ class SelfAttention(Attention):
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.insert(0, Parallel(Identity(), Identity(), Identity()))
|
self.insert(
|
||||||
|
index=0,
|
||||||
|
module=Parallel(
|
||||||
|
Identity(), # Query projection's input
|
||||||
|
Identity(), # Key projection's input
|
||||||
|
Identity(), # Value projection's input
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention2d(SelfAttention):
|
class SelfAttention2d(SelfAttention):
|
||||||
|
"""Multi-Head 2D Self-Attention layer.
|
||||||
|
|
||||||
|
Note: This Module simply chains
|
||||||
|
- a [`Lambda`][refiners.fluxion.layers.chain.Lambda] layer,
|
||||||
|
which transforms the input Tensor into a sequence
|
||||||
|
- a [`SelfAttention`][refiners.fluxion.layers.attentions.SelfAttention] layer
|
||||||
|
- a [`Lambda`][refiners.fluxion.layers.chain.Lambda] layer,
|
||||||
|
which transforms the output sequence into a 2D Tensor
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Float[Tensor, "batch channels height width"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch channels height width"]):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
self_attention = fl.SelfAttention2d(channels=128, num_heads=8)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 128, 64, 64)
|
||||||
|
output = self_attention(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 128, 64, 64)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
num_heads: int = 1,
|
num_heads: int = 1,
|
||||||
use_bias: bool = True,
|
use_bias: bool = True,
|
||||||
is_causal: bool | None = None,
|
is_causal: bool = False,
|
||||||
is_optimized: bool = True,
|
is_optimized: bool = True,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the 2D Self-Attention layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: The number of channels of the input and output tensors.
|
||||||
|
num_heads: The number of heads to use.
|
||||||
|
use_bias: Whether to use bias in the linear layers.
|
||||||
|
is_causal: Whether to use causal attention.
|
||||||
|
is_optimized: Whether to use optimized attention.
|
||||||
|
device: The device to use.
|
||||||
|
dtype: The dtype to use.
|
||||||
|
"""
|
||||||
assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
|
assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
embedding_dim=channels,
|
embedding_dim=channels,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
|
@ -226,21 +443,45 @@ class SelfAttention2d(SelfAttention):
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.insert(0, Lambda(self.tensor_2d_to_sequence))
|
|
||||||
self.append(Lambda(self.sequence_to_tensor_2d))
|
self.insert(0, Lambda(self._tensor_2d_to_sequence))
|
||||||
|
self.append(Lambda(self._sequence_to_tensor_2d))
|
||||||
|
|
||||||
def init_context(self) -> Contexts:
|
def init_context(self) -> Contexts:
|
||||||
return {"reshape": {"height": None, "width": None}}
|
return {
|
||||||
|
"reshape": {
|
||||||
|
"height": None,
|
||||||
|
"width": None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def tensor_2d_to_sequence(
|
def _tensor_2d_to_sequence(
|
||||||
self, x: Float[Tensor, "batch channels height width"]
|
self,
|
||||||
|
x: Float[Tensor, "batch channels height width"],
|
||||||
) -> Float[Tensor, "batch height*width channels"]:
|
) -> Float[Tensor, "batch height*width channels"]:
|
||||||
|
"""Transform a 2D Tensor into a sequence.
|
||||||
|
|
||||||
|
The height and width of the input Tensor are stored in the context,
|
||||||
|
so that the output Tensor can be transformed back into a 2D Tensor in the `sequence_to_tensor_2d` method.
|
||||||
|
"""
|
||||||
height, width = x.shape[-2:]
|
height, width = x.shape[-2:]
|
||||||
self.set_context(context="reshape", value={"height": height, "width": width})
|
self.set_context(
|
||||||
|
context="reshape",
|
||||||
|
value={
|
||||||
|
"height": height,
|
||||||
|
"width": width,
|
||||||
|
},
|
||||||
|
)
|
||||||
return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)
|
return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)
|
||||||
|
|
||||||
def sequence_to_tensor_2d(
|
def _sequence_to_tensor_2d(
|
||||||
self, x: Float[Tensor, "batch sequence_length channels"]
|
self,
|
||||||
|
x: Float[Tensor, "batch sequence_length channels"],
|
||||||
) -> Float[Tensor, "batch channels height width"]:
|
) -> Float[Tensor, "batch channels height width"]:
|
||||||
|
"""Transform a sequence into a 2D Tensor.
|
||||||
|
|
||||||
|
The height and width of the output Tensor are retrieved from the context,
|
||||||
|
which was set in the `tensor_2d_to_sequence` method.
|
||||||
|
"""
|
||||||
height, width = self.use_context("reshape").values()
|
height, width = self.use_context("reshape").values()
|
||||||
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)
|
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)
|
||||||
|
|
Loading…
Reference in a new issue