Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)
feature: add sliced-attention for memory efficiency
This allowed me to produce HD images on an M1 with 32 GB of RAM, and 7000x5000 images on an Nvidia 4090. I saw no visual difference in the generated images.

Some data points on slice_size:
# 4096 max needed for SD 1.5 512x512
# 9216 max needed for SD 1.5 768x768
# 16384 max needed for SD 1.5 1024x1024
# 32400 max needed for SD 1.5 1920x1080 (HD)
# 129600 max needed for SD 1.5 3840x2160 (4k)
# 234375 max needed for SD 1.5 5000x3000
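These maxima are simply the number of query tokens in the largest self-attention layer at each resolution: the latent grid is the pixel size divided by 8, so 512x512 gives 64 * 64 = 4096. A minimal sketch of that arithmetic, assuming the usual SD 1.5 VAE downscale factor of 8 (the factor is not stated in this commit):

def max_slice_size(width: int, height: int, downscale: int = 8) -> int:
    # Largest slice_size ever needed: the full number of latent positions,
    # i.e. one query token per latent pixel in the highest-resolution attention.
    return (width // downscale) * (height // downscale)

assert max_slice_size(512, 512) == 4096
assert max_slice_size(1920, 1080) == 32400
assert max_slice_size(5000, 3000) == 234375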
This commit is contained in:
parent b306c7db1b
commit 4176868e79
@@ -1,4 +1,5 @@
 import math
 
+import torch
 from jaxtyping import Float
 from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention  # type: ignore
@@ -37,11 +38,18 @@ def sparse_dot_product_attention_non_optimized(
 
 
 class ScaledDotProductAttention(Module):
-    def __init__(self, num_heads: int = 1, is_causal: bool | None = None, is_optimized: bool = True) -> None:
+    def __init__(
+        self,
+        num_heads: int = 1,
+        is_causal: bool | None = None,
+        is_optimized: bool = True,
+        slice_size: int | None = None,
+    ) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.is_causal = is_causal
         self.is_optimized = is_optimized
+        self.slice_size = slice_size
         self.dot_product = (
             scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
         )
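The constructor change above makes slicing opt-in: slice_size=None keeps the previous single-pass behaviour, while an integer caps how many query rows are attended at once. A hypothetical usage sketch; the import path is an assumption and is not part of this diff:

# Hypothetical usage; the import path is assumed, not shown in this diff.
from refiners.fluxion.layers.attentions import ScaledDotProductAttention

attn = ScaledDotProductAttention(num_heads=8)                     # original single-pass attention
sliced = ScaledDotProductAttention(num_heads=8, slice_size=4096)  # enough for SD 1.5 at 512x512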
@@ -52,6 +60,35 @@ class ScaledDotProductAttention(Module):
         key: Float[Tensor, "batch num_keys embedding_dim"],
         value: Float[Tensor, "batch num_values embedding_dim"],
         is_causal: bool | None = None,
+    ) -> Float[Tensor, "batch num_queries dim"]:
+        if self.slice_size is None:
+            return self._process_attention(query, key, value, is_causal)
+
+        return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size)
+
+    def _sliced_attention(
+        self,
+        query: Float[Tensor, "batch num_queries embedding_dim"],
+        key: Float[Tensor, "batch num_keys embedding_dim"],
+        value: Float[Tensor, "batch num_values embedding_dim"],
+        slice_size: int,
+        is_causal: bool | None = None,
+    ) -> Float[Tensor, "batch num_queries dim"]:
+        _, num_queries, _ = query.shape
+        output = torch.zeros_like(query)
+        for start_idx in range(0, num_queries, slice_size):
+            end_idx = min(start_idx + slice_size, num_queries)
+            output[:, start_idx:end_idx, :] = self._process_attention(
+                query[:, start_idx:end_idx, :], key, value, is_causal
+            )
+        return output
+
+    def _process_attention(
+        self,
+        query: Float[Tensor, "batch num_queries embedding_dim"],
+        key: Float[Tensor, "batch num_keys embedding_dim"],
+        value: Float[Tensor, "batch num_values embedding_dim"],
+        is_causal: bool | None = None,
     ) -> Float[Tensor, "batch num_queries dim"]:
         return self.merge_multi_head(
             x=self.dot_product(
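For non-causal attention, slicing the queries does not change the result: each output row depends only on its own query row and the full key/value, so chunking the query dimension and writing the chunks back reproduces the single call while shrinking the peak attention matrix from num_queries x num_keys to slice_size x num_keys. A standalone sketch of that equivalence in plain PyTorch (not the refiners API):

import torch
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(2, 4096, 64)
k = torch.randn(2, 4096, 64)
v = torch.randn(2, 4096, 64)

full = scaled_dot_product_attention(q, k, v)

slice_size = 1024
sliced = torch.zeros_like(q)
for start in range(0, q.shape[1], slice_size):
    end = min(start + slice_size, q.shape[1])
    # Attend a chunk of queries against all keys/values, as the commit does per slice.
    sliced[:, start:end, :] = scaled_dot_product_attention(q[:, start:end, :], k, v)

assert torch.allclose(full, sliced, atol=1e-5)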