feature: add sliced-attention for memory efficiency

This allowed me to produce HD images on M1 32gb and 7000x5000 on Nvidia 4090

I saw no visual difference in images generated.

Some datapoints on slice_size
# 4096 max needed for SD 1.5 512x512
# 9216 max needed for SD 1.5 768x768
# 16384 max needed for SD 1.5 1024x1024
# 32400 max needed for SD 1.5 1920x1080 (HD)
# 129600 max needed for SD 1.5 3840x2160 (4k)
# 234375 max needed for SD 1.5 5000x3000
This commit is contained in:
Bryce 2023-11-30 22:55:38 -08:00 committed by Cédric Deltheil
parent b306c7db1b
commit 4176868e79

View file

@ -1,4 +1,5 @@
import math
import torch
from jaxtyping import Float
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore
@ -37,11 +38,18 @@ def sparse_dot_product_attention_non_optimized(
class ScaledDotProductAttention(Module):
def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
def __init__(
self,
num_heads: int = 1,
is_causal: bool | None = None,
is_optimized: bool = True,
slice_size: int | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.is_causal = is_causal
self.is_optimized = is_optimized
self.slice_size = slice_size
self.dot_product = (
scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
)
@ -52,6 +60,35 @@ class ScaledDotProductAttention(Module):
key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"],
is_causal: bool | None = None,
) -> Float[Tensor, "batch num_queries dim"]:
if self.slice_size is None:
return self._process_attention(query, key, value, is_causal)
return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size)
def _sliced_attention(
self,
query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"],
slice_size: int,
is_causal: bool | None = None,
) -> Float[Tensor, "batch num_queries dim"]:
_, num_queries, _ = query.shape
output = torch.zeros_like(query)
for start_idx in range(0, num_queries, slice_size):
end_idx = min(start_idx + slice_size, num_queries)
output[:, start_idx:end_idx, :] = self._process_attention(
query[:, start_idx:end_idx, :], key, value, is_causal
)
return output
def _process_attention(
self,
query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"],
is_causal: bool | None = None,
) -> Float[Tensor, "batch num_queries dim"]:
return self.merge_multi_head(
x=self.dot_product(