add support for self-attention guidance

See https://arxiv.org/abs/2210.00939
This commit is contained in:
Cédric Deltheil 2023-10-09 16:57:58 +02:00
parent 976b55aea5
commit d3365d6383
10 changed files with 380 additions and 4 deletions

View file

@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
T = TypeVar("T", bound="fl.Module")
@ -68,6 +67,17 @@ class LatentDiffusionModel(fl.Module, ABC):
@abstractmethod
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...
@abstractmethod
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: ...
@abstractmethod
def has_self_attention_guidance(self) -> bool: ...
@abstractmethod
def compute_self_attention_guidance(
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor: ...
def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
) -> Tensor:
@ -80,6 +90,12 @@ class LatentDiffusionModel(fl.Module, ABC):
# classifier-free guidance
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
if self.has_self_attention_guidance():
noise += self.compute_self_attention_guidance(
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
)
return self.scheduler(x, noise=noise, step=step)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:

View file

@ -0,0 +1,101 @@
from typing import Any, Generic, TypeVar, TYPE_CHECKING
import math
from torch import Tensor, Size
from jaxtyping import Float
import torch
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.fluxion.utils import interpolate, gaussian_blur
import refiners.fluxion.layers as fl
if TYPE_CHECKING:
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TSAGAdapter = TypeVar("TSAGAdapter", bound="SAGAdapter[Any]") # Self (see PEP 673)
class SelfAttentionMap(fl.Passthrough):
def __init__(self, num_heads: int, context_key: str) -> None:
self.num_heads = num_heads
self.context_key = context_key
super().__init__(
fl.Lambda(func=self.compute_attention_scores),
fl.SetContext(context="self_attention_map", key=context_key),
)
def split_to_multi_head(
self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
assert (
len(x.shape) == 3
), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
assert (
x.shape[-1] % self.num_heads == 0
), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
def compute_attention_scores(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
query, key = self.split_to_multi_head(query), self.split_to_multi_head(key)
_, _, _, dim = query.shape
attention = query @ key.permute(0, 1, 3, 2)
attention = attention / math.sqrt(dim)
return torch.softmax(input=attention, dim=-1)
class SelfAttentionShape(fl.Passthrough):
def __init__(self, context_key: str) -> None:
self.context_key = context_key
super().__init__(
fl.SetContext(context="self_attention_map", key=context_key, callback=self.register_shape),
)
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
assert x.ndim == 4, f"Expected 4D tensor, got {x.ndim}D with shape {x.shape}"
shapes.append(x.shape[-2:])
class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(self, target: T, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
self.scale = scale
self.kernel_size = kernel_size
self.sigma = sigma
with self.setup_adapter(target):
super().__init__(target)
def inject(self: "TSAGAdapter", parent: fl.Chain | None = None) -> "TSAGAdapter":
return super().inject(parent)
def eject(self) -> None:
super().eject()
def compute_sag_mask(
self, latents: Float[Tensor, "batch_size channels height width"], classifier_free_guidance: bool = True
) -> Float[Tensor, "batch_size channels height width"]:
attn_map = self.use_context("self_attention_map")["middle_block_attn_map"]
if classifier_free_guidance:
unconditional_attn, _ = attn_map.chunk(2)
attn_map = unconditional_attn
attn_shape = self.use_context("self_attention_map")["middle_block_attn_shape"].pop()
assert len(attn_shape) == 2
b, c, h, w = latents.shape
attn_h, attn_w = attn_shape
attn_mask = attn_map.mean(dim=1, keepdim=False).sum(dim=1, keepdim=False) > 1.0
attn_mask = attn_mask.reshape(b, attn_h, attn_w).unsqueeze(1).repeat(1, c, 1, 1).type(attn_map.dtype)
return interpolate(attn_mask, Size((h, w)))
def compute_degraded_latents(
self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
) -> Tensor:
sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
return scheduler.add_noise(degraded_latents, noise=noise, step=step)
def init_context(self) -> Contexts:
return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}

View file

@ -6,6 +6,7 @@ from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
from PIL import Image
import numpy as np
from torch import device as Device, dtype as DType, Tensor
@ -54,6 +55,47 @@ class StableDiffusion_1(LatentDiffusionModel):
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
if enable:
if sag := self._find_sag_adapter():
sag.scale = scale
else:
sag = SD1SAGAdapter(target=self.unet, scale=scale)
sag.inject()
else:
if sag := self._find_sag_adapter():
sag.eject()
def has_self_attention_guidance(self) -> bool:
return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SD1SAGAdapter | None:
for p in self.unet.get_parents():
if isinstance(p, SD1SAGAdapter):
return p
return None
def compute_self_attention_guidance(
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
sag = self._find_sag_adapter()
assert sag is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)
class StableDiffusion_1_Inpainting(StableDiffusion_1):
def __init__(

View file

@ -0,0 +1,41 @@
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
SAGAdapter,
SelfAttentionShape,
SelfAttentionMap,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, MiddleBlock, ResidualBlock
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
import refiners.fluxion.layers as fl
class SD1SAGAdapter(SAGAdapter[SD1UNet]):
def __init__(self, target: SD1UNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
super().__init__(
target=target,
scale=scale,
kernel_size=kernel_size,
sigma=sigma,
)
def inject(self: "SD1SAGAdapter", parent: fl.Chain | None = None) -> "SD1SAGAdapter":
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))
# An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
# scores to avoid computing these scores twice
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.insert_before_type(
ScaledDotProductAttention,
SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
)
return super().inject(parent)
def eject(self) -> None:
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.remove(middle_block.ensure_find(SelfAttentionShape))
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.remove(self_attn.ensure_find(SelfAttentionMap))
super().eject()

View file

@ -4,6 +4,7 @@ from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from torch import device as Device, dtype as DType, Tensor
@ -67,7 +68,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
**_: Tensor
**_: Tensor,
) -> None:
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
@ -83,7 +84,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
pooled_text_embedding: Tensor,
time_ids: Tensor,
condition_scale: float = 5.0,
**kwargs: Tensor
**kwargs: Tensor,
) -> Tensor:
return super().forward(
x=x,
@ -92,5 +93,62 @@ class StableDiffusion_XL(LatentDiffusionModel):
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=condition_scale,
**kwargs
**kwargs,
)
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
if enable:
if sag := self._find_sag_adapter():
sag.scale = scale
else:
sag = SDXLSAGAdapter(target=self.unet, scale=scale)
sag.inject()
else:
if sag := self._find_sag_adapter():
sag.eject()
def has_self_attention_guidance(self) -> bool:
return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
for p in self.unet.get_parents():
if isinstance(p, SDXLSAGAdapter):
return p
return None
def compute_self_attention_guidance(
self,
x: Tensor,
noise: Tensor,
step: int,
*,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
**kwargs: Tensor,
) -> Tensor:
sag = self._find_sag_adapter()
assert sag is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
time_ids, _ = time_ids.chunk(2)
self.set_unet_context(
timestep=timestep,
clip_text_embedding=negative_embedding,
pooled_text_embedding=negative_pooled_embedding,
time_ids=time_ids,
**kwargs,
)
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)

View file

@ -0,0 +1,41 @@
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
SAGAdapter,
SelfAttentionShape,
SelfAttentionMap,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet, MiddleBlock, ResidualBlock
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
import refiners.fluxion.layers as fl
class SDXLSAGAdapter(SAGAdapter[SDXLUNet]):
def __init__(self, target: SDXLUNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
super().__init__(
target=target,
scale=scale,
kernel_size=kernel_size,
sigma=sigma,
)
def inject(self: "SDXLSAGAdapter", parent: fl.Chain | None = None) -> "SDXLSAGAdapter":
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))
# An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
# scores to avoid computing these scores twice
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.insert_before_type(
ScaledDotProductAttention,
SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
)
return super().inject(parent)
def eject(self) -> None:
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.remove(middle_block.ensure_find(SelfAttentionShape))
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.remove(self_attn.ensure_find(SelfAttentionMap))
super().eject()

View file

@ -64,6 +64,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
@pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
@ -109,6 +114,11 @@ def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
@pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert(mode="RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@ -514,6 +524,35 @@ def test_diffusion_std_random_init_float16(
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_std_random_init_sag(
sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
):
sd15 = sd15_std
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
@torch.no_grad()
def test_diffusion_std_init_image(
sd15_std: StableDiffusion_1,
@ -1364,6 +1403,42 @@ def test_sdxl_random_init(
ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_sdxl_random_init_sag(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
) -> None:
sdxl = sdxl_ddim
expected_image = expected_sdxl_ddim_random_init_sag
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(num_inference_steps=n_steps)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(seed=2)
x = torch.randn(1, 4, 128, 128, device=test_device)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.decode_latents(x=x)
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
@torch.no_grad()
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
manual_seed(seed=2)

View file

@ -34,6 +34,7 @@ output.images[0].save("std_random_init_expected.png")
Special cases:
- For self-attention guidance, `StableDiffusionSAGPipeline` has been used instead of the default pipeline.
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- The following references have been generated with refiners itself (and inspected so that they look reasonable):
- `expected_inpainting_refonly.png`,
@ -42,6 +43,7 @@ Special cases:
- `expected_ip_adapter_controlnet.png`
- `expected_t2i_adapter_xl_canny.png`
- `expected_image_sdxl_ip_adapter_plus_woman.png`
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
## Other images

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 493 KiB