mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
feature: add sliced-attention for memory efficiency
This allowed me to produce HD images on M1 32gb and 7000x5000 on Nvidia 4090 I saw no visual difference in images generated. Some datapoints on slice_size # 4096 max needed for SD 1.5 512x512 # 9216 max needed for SD 1.5 768x768 # 16384 max needed for SD 1.5 1024x1024 # 32400 max needed for SD 1.5 1920x1080 (HD) # 129600 max needed for SD 1.5 3840x2160 (4k) # 234375 max needed for SD 1.5 5000x3000
This commit is contained in:
parent
b306c7db1b
commit
4176868e79
|
@ -1,4 +1,5 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore
|
||||
|
@ -37,11 +38,18 @@ def sparse_dot_product_attention_non_optimized(
|
|||
|
||||
|
||||
class ScaledDotProductAttention(Module):
|
||||
def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int = 1,
|
||||
is_causal: bool | None = None,
|
||||
is_optimized: bool = True,
|
||||
slice_size: int | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.is_causal = is_causal
|
||||
self.is_optimized = is_optimized
|
||||
self.slice_size = slice_size
|
||||
self.dot_product = (
|
||||
scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
|
||||
)
|
||||
|
@ -52,6 +60,35 @@ class ScaledDotProductAttention(Module):
|
|||
key: Float[Tensor, "batch num_keys embedding_dim"],
|
||||
value: Float[Tensor, "batch num_values embedding_dim"],
|
||||
is_causal: bool | None = None,
|
||||
) -> Float[Tensor, "batch num_queries dim"]:
|
||||
if self.slice_size is None:
|
||||
return self._process_attention(query, key, value, is_causal)
|
||||
|
||||
return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size)
|
||||
|
||||
def _sliced_attention(
|
||||
self,
|
||||
query: Float[Tensor, "batch num_queries embedding_dim"],
|
||||
key: Float[Tensor, "batch num_keys embedding_dim"],
|
||||
value: Float[Tensor, "batch num_values embedding_dim"],
|
||||
slice_size: int,
|
||||
is_causal: bool | None = None,
|
||||
) -> Float[Tensor, "batch num_queries dim"]:
|
||||
_, num_queries, _ = query.shape
|
||||
output = torch.zeros_like(query)
|
||||
for start_idx in range(0, num_queries, slice_size):
|
||||
end_idx = min(start_idx + slice_size, num_queries)
|
||||
output[:, start_idx:end_idx, :] = self._process_attention(
|
||||
query[:, start_idx:end_idx, :], key, value, is_causal
|
||||
)
|
||||
return output
|
||||
|
||||
def _process_attention(
|
||||
self,
|
||||
query: Float[Tensor, "batch num_queries embedding_dim"],
|
||||
key: Float[Tensor, "batch num_keys embedding_dim"],
|
||||
value: Float[Tensor, "batch num_values embedding_dim"],
|
||||
is_causal: bool | None = None,
|
||||
) -> Float[Tensor, "batch num_queries dim"]:
|
||||
return self.merge_multi_head(
|
||||
x=self.dot_product(
|
||||
|
|
Loading…
Reference in a new issue