From 4176868e792317a2290edee4ca27284bf1f3053f Mon Sep 17 00:00:00 2001
From: Bryce
Date: Thu, 30 Nov 2023 22:55:38 -0800
Subject: [PATCH] feature: add sliced-attention for memory efficiency

This allowed me to produce HD images on an M1 (32 GB) and 7000x5000 images on an Nvidia 4090.

I saw no visual difference in the generated images.

Some datapoints on slice_size:
# 4096 max needed for SD 1.5 512x512
# 9216 max needed for SD 1.5 768x768
# 16384 max needed for SD 1.5 1024x1024
# 32400 max needed for SD 1.5 1920x1080 (HD)
# 129600 max needed for SD 1.5 3840x2160 (4k)
# 234375 max needed for SD 1.5 5000x3000
---
 src/refiners/fluxion/layers/attentions.py | 39 ++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/src/refiners/fluxion/layers/attentions.py b/src/refiners/fluxion/layers/attentions.py
index f7f2200..2618b3e 100644
--- a/src/refiners/fluxion/layers/attentions.py
+++ b/src/refiners/fluxion/layers/attentions.py
@@ -1,4 +1,5 @@
 import math
+
 import torch
 from jaxtyping import Float
 from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention  # type: ignore
@@ -37,11 +38,18 @@ def sparse_dot_product_attention_non_optimized(
 
 
 class ScaledDotProductAttention(Module):
-    def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
+    def __init__(
+        self,
+        num_heads: int = 1,
+        is_causal: bool | None = None,
+        is_optimized: bool = True,
+        slice_size: int | None = None,
+    ) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.is_causal = is_causal
         self.is_optimized = is_optimized
+        self.slice_size = slice_size
         self.dot_product = (
             scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
         )
@@ -52,6 +60,35 @@ class ScaledDotProductAttention(Module):
         key: Float[Tensor, "batch num_keys embedding_dim"],
         value: Float[Tensor, "batch num_values embedding_dim"],
         is_causal: bool | None = None,
+    ) -> Float[Tensor, "batch num_queries dim"]:
+        if self.slice_size is None:
+            return self._process_attention(query, key, value, is_causal)
+
+        return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size)
+
+    def _sliced_attention(
+        self,
+        query: Float[Tensor, "batch num_queries embedding_dim"],
+        key: Float[Tensor, "batch num_keys embedding_dim"],
+        value: Float[Tensor, "batch num_values embedding_dim"],
+        slice_size: int,
+        is_causal: bool | None = None,
+    ) -> Float[Tensor, "batch num_queries dim"]:
+        _, num_queries, _ = query.shape
+        output = torch.zeros_like(query)
+        for start_idx in range(0, num_queries, slice_size):
+            end_idx = min(start_idx + slice_size, num_queries)
+            output[:, start_idx:end_idx, :] = self._process_attention(
+                query[:, start_idx:end_idx, :], key, value, is_causal
+            )
+        return output
+
+    def _process_attention(
+        self,
+        query: Float[Tensor, "batch num_queries embedding_dim"],
+        key: Float[Tensor, "batch num_keys embedding_dim"],
+        value: Float[Tensor, "batch num_values embedding_dim"],
+        is_causal: bool | None = None,
     ) -> Float[Tensor, "batch num_queries dim"]:
         return self.merge_multi_head(
             x=self.dot_product(
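
Note for reviewers (not part of the patch): the slice_size datapoints in the commit message appear to be the largest self-attention sequence length for SD 1.5 at each resolution, i.e. (height/8) * (width/8) latent tokens; for example 512x512 gives 64 * 64 = 4096. The minimal sketch below reproduces the slicing idea in plain PyTorch so the logic can be sanity-checked outside the library. The helper name sliced_sdpa and the tensor shapes are illustrative assumptions, and for simplicity the sketch slices already head-split (batch, heads, queries, dim) tensors, whereas the patch slices before head splitting.

import torch
from torch.nn.functional import scaled_dot_product_attention

def sliced_sdpa(query, key, value, slice_size):
    # query/key/value: (batch, num_heads, sequence_length, head_dim).
    # Each chunk of queries attends to the full key/value tensors, so the
    # (num_queries x num_keys) attention matrix is never materialized at once.
    output = torch.zeros_like(query)
    num_queries = query.shape[2]
    for start in range(0, num_queries, slice_size):
        end = min(start + slice_size, num_queries)
        output[:, :, start:end, :] = scaled_dot_product_attention(
            query[:, :, start:end, :], key, value
        )
    return output

# Illustrative shapes: batch=1, 8 heads, 4096 queries (SD 1.5 at 512x512), head_dim=64.
q, k, v = (torch.randn(1, 8, 4096, 64) for _ in range(3))
full = scaled_dot_product_attention(q, k, v)
sliced = sliced_sdpa(q, k, v, slice_size=1024)
print((full - sliced).abs().max())  # expected to be floating-point noise only

With the patch applied, the equivalent behavior should be reachable through the new constructor argument, e.g. ScaledDotProductAttention(num_heads=..., slice_size=4096), since forward dispatches to _sliced_attention whenever slice_size is set.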