add support for self-attention guidance

See https://arxiv.org/abs/2210.00939
Cédric Deltheil 2023-10-09 16:57:58 +02:00 committed by Cédric Deltheil
parent 976b55aea5
commit b80769939d
10 changed files with 380 additions and 4 deletions
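Not part of the diff: a minimal usage sketch of the API this commit adds, mirroring the e2e test further down. The import path, construction and device handling are assumptions; only set_self_attention_guidance and the denoising-loop calls come from the diff.

import torch
from refiners.foundationals.latent_diffusion import StableDiffusion_1  # assumed import path

# Hypothetical setup: construct the model and load weights as usual for refiners.
sd15 = StableDiffusion_1(device=torch.device("cuda"))

prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

sd15.set_num_inference_steps(30)
sd15.set_self_attention_guidance(enable=True, scale=0.75)  # new API introduced by this commit

x = torch.randn(1, 4, 64, 64, device=torch.device("cuda"))
for step in sd15.steps:
    x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
predicted_image = sd15.lda.decode_latents(x)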


@@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

T = TypeVar("T", bound="fl.Module")
@@ -68,6 +67,17 @@ class LatentDiffusionModel(fl.Module, ABC):
    @abstractmethod
    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...

    @abstractmethod
    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: ...

    @abstractmethod
    def has_self_attention_guidance(self) -> bool: ...

    @abstractmethod
    def compute_self_attention_guidance(
        self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
    ) -> Tensor: ...

    def forward(
        self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
    ) -> Tensor:
@@ -80,6 +90,12 @@ class LatentDiffusionModel(fl.Module, ABC):
        # classifier-free guidance
        noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
        x = x.narrow(dim=1, start=0, length=4)  # support > 4 channels for inpainting

        if self.has_self_attention_guidance():
            noise += self.compute_self_attention_guidance(
                x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
            )

        return self.scheduler(x, noise=noise, step=step)

    def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
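For clarity (not part of the diff): when the adapter is active, the noise estimate handed to the scheduler is the classifier-free-guidance prediction plus a SAG correction derived from the unconditional prediction and a prediction on degraded (blurred) latents. A condensed, illustrative restatement of the forward logic above:

# Illustrative restatement of LatentDiffusionModel.forward with SAG enabled (variable names made up).
cfg_noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
sag_correction = sag_scale * (unconditional_prediction - degraded_prediction)  # see compute_self_attention_guidance
noise = cfg_noise + sag_correction
x_next = scheduler(x, noise=noise, step=step)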


@@ -0,0 +1,101 @@
from typing import Any, Generic, TypeVar, TYPE_CHECKING
import math

from torch import Tensor, Size
from jaxtyping import Float
import torch

from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.fluxion.utils import interpolate, gaussian_blur
import refiners.fluxion.layers as fl

if TYPE_CHECKING:
    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TSAGAdapter = TypeVar("TSAGAdapter", bound="SAGAdapter[Any]")  # Self (see PEP 673)


class SelfAttentionMap(fl.Passthrough):
    def __init__(self, num_heads: int, context_key: str) -> None:
        self.num_heads = num_heads
        self.context_key = context_key
        super().__init__(
            fl.Lambda(func=self.compute_attention_scores),
            fl.SetContext(context="self_attention_map", key=context_key),
        )

    def split_to_multi_head(
        self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
    ) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
        assert (
            len(x.shape) == 3
        ), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
        assert (
            x.shape[-1] % self.num_heads == 0
        ), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
        return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)

    def compute_attention_scores(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        query, key = self.split_to_multi_head(query), self.split_to_multi_head(key)
        _, _, _, dim = query.shape
        attention = query @ key.permute(0, 1, 3, 2)
        attention = attention / math.sqrt(dim)
        return torch.softmax(input=attention, dim=-1)


class SelfAttentionShape(fl.Passthrough):
    def __init__(self, context_key: str) -> None:
        self.context_key = context_key
        super().__init__(
            fl.SetContext(context="self_attention_map", key=context_key, callback=self.register_shape),
        )

    def register_shape(self, shapes: list[Size], x: Tensor) -> None:
        assert x.ndim == 4, f"Expected 4D tensor, got {x.ndim}D with shape {x.shape}"
        shapes.append(x.shape[-2:])


class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
    def __init__(self, target: T, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
        self.scale = scale
        self.kernel_size = kernel_size
        self.sigma = sigma
        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: "TSAGAdapter", parent: fl.Chain | None = None) -> "TSAGAdapter":
        return super().inject(parent)

    def eject(self) -> None:
        super().eject()

    def compute_sag_mask(
        self, latents: Float[Tensor, "batch_size channels height width"], classifier_free_guidance: bool = True
    ) -> Float[Tensor, "batch_size channels height width"]:
        attn_map = self.use_context("self_attention_map")["middle_block_attn_map"]
        if classifier_free_guidance:
            unconditional_attn, _ = attn_map.chunk(2)
            attn_map = unconditional_attn

        attn_shape = self.use_context("self_attention_map")["middle_block_attn_shape"].pop()
        assert len(attn_shape) == 2
        b, c, h, w = latents.shape
        attn_h, attn_w = attn_shape
        attn_mask = attn_map.mean(dim=1, keepdim=False).sum(dim=1, keepdim=False) > 1.0
        attn_mask = attn_mask.reshape(b, attn_h, attn_w).unsqueeze(1).repeat(1, c, 1, 1).type(attn_map.dtype)
        return interpolate(attn_mask, Size((h, w)))

    def compute_degraded_latents(
        self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
    ) -> Tensor:
        sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
        original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
        degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
        degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
        return scheduler.add_noise(degraded_latents, noise=noise, step=step)

    def init_context(self) -> Contexts:
        return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}


@@ -6,6 +6,7 @@ from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
from PIL import Image
import numpy as np
from torch import device as Device, dtype as DType, Tensor
@@ -54,6 +55,47 @@ class StableDiffusion_1(LatentDiffusionModel):
        self.unet.set_timestep(timestep=timestep)
        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
        if enable:
            if sag := self._find_sag_adapter():
                sag.scale = scale
            else:
                sag = SD1SAGAdapter(target=self.unet, scale=scale)
                sag.inject()
        else:
            if sag := self._find_sag_adapter():
                sag.eject()

    def has_self_attention_guidance(self) -> bool:
        return self._find_sag_adapter() is not None

    def _find_sag_adapter(self) -> SD1SAGAdapter | None:
        for p in self.unet.get_parents():
            if isinstance(p, SD1SAGAdapter):
                return p
        return None

    def compute_self_attention_guidance(
        self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
    ) -> Tensor:
        sag = self._find_sag_adapter()
        assert sag is not None

        degraded_latents = sag.compute_degraded_latents(
            scheduler=self.scheduler,
            latents=x,
            noise=noise,
            step=step,
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
        degraded_noise = self.unet(degraded_latents)

        return sag.scale * (noise - degraded_noise)


class StableDiffusion_1_Inpainting(StableDiffusion_1):
    def __init__(


@@ -0,0 +1,41 @@
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
    SAGAdapter,
    SelfAttentionShape,
    SelfAttentionMap,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, MiddleBlock, ResidualBlock
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
import refiners.fluxion.layers as fl


class SD1SAGAdapter(SAGAdapter[SD1UNet]):
    def __init__(self, target: SD1UNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
        super().__init__(
            target=target,
            scale=scale,
            kernel_size=kernel_size,
            sigma=sigma,
        )

    def inject(self: "SD1SAGAdapter", parent: fl.Chain | None = None) -> "SD1SAGAdapter":
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))

        # An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
        # scores to avoid computing these scores twice
        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.insert_before_type(
            ScaledDotProductAttention,
            SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
        )

        return super().inject(parent)

    def eject(self) -> None:
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.remove(middle_block.ensure_find(SelfAttentionShape))

        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.remove(self_attn.ensure_find(SelfAttentionMap))

        super().eject()
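Regarding the comment in inject above (not part of the diff): a plain-torch illustration of the alternative it mentions, i.e. computing the attention scores once and reusing them for both the attention output and the recorded map, instead of re-deriving them in a separate SelfAttentionMap pass. This is not refiners code; names and shapes are made up.

import math
import torch

def attention_with_recorded_scores(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # q, k, v: (batch, heads, sequence, dim_per_head)
    scores = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
    return scores @ v, scores  # the second value is the self-attention map, computed only once

q = k = v = torch.randn(2, 8, 64, 40)
output, recorded_map = attention_with_recorded_scores(q, k, v)  # output: (2, 8, 64, 40); map: (2, 8, 64, 64)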


@@ -4,6 +4,7 @@ from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from torch import device as Device, dtype as DType, Tensor
@@ -67,7 +68,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        **_: Tensor,
    ) -> None:
        self.unet.set_timestep(timestep=timestep)
        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
@@ -83,7 +84,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        condition_scale: float = 5.0,
        **kwargs: Tensor,
    ) -> Tensor:
        return super().forward(
            x=x,
@@ -92,5 +93,62 @@ class StableDiffusion_XL(LatentDiffusionModel):
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=condition_scale,
            **kwargs,
        )

    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
        if enable:
            if sag := self._find_sag_adapter():
                sag.scale = scale
            else:
                sag = SDXLSAGAdapter(target=self.unet, scale=scale)
                sag.inject()
        else:
            if sag := self._find_sag_adapter():
                sag.eject()

    def has_self_attention_guidance(self) -> bool:
        return self._find_sag_adapter() is not None

    def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
        for p in self.unet.get_parents():
            if isinstance(p, SDXLSAGAdapter):
                return p
        return None

    def compute_self_attention_guidance(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        **kwargs: Tensor,
    ) -> Tensor:
        sag = self._find_sag_adapter()
        assert sag is not None

        degraded_latents = sag.compute_degraded_latents(
            scheduler=self.scheduler,
            latents=x,
            noise=noise,
            step=step,
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        time_ids, _ = time_ids.chunk(2)
        self.set_unet_context(
            timestep=timestep,
            clip_text_embedding=negative_embedding,
            pooled_text_embedding=negative_pooled_embedding,
            time_ids=time_ids,
            **kwargs,
        )
        degraded_noise = self.unet(degraded_latents)

        return sag.scale * (noise - degraded_noise)


@@ -0,0 +1,41 @@
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
    SAGAdapter,
    SelfAttentionShape,
    SelfAttentionMap,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet, MiddleBlock, ResidualBlock
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
import refiners.fluxion.layers as fl


class SDXLSAGAdapter(SAGAdapter[SDXLUNet]):
    def __init__(self, target: SDXLUNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
        super().__init__(
            target=target,
            scale=scale,
            kernel_size=kernel_size,
            sigma=sigma,
        )

    def inject(self: "SDXLSAGAdapter", parent: fl.Chain | None = None) -> "SDXLSAGAdapter":
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))

        # An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
        # scores to avoid computing these scores twice
        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.insert_before_type(
            ScaledDotProductAttention,
            SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
        )

        return super().inject(parent)

    def eject(self) -> None:
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.remove(middle_block.ensure_find(SelfAttentionShape))

        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.remove(self_attn.ensure_find(SelfAttentionMap))

        super().eject()


@@ -64,6 +64,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")


@pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
@@ -109,6 +114,11 @@ def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
    return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")


@pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
    return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert(mode="RGB")


@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
    ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@@ -514,6 +524,35 @@ def test_diffusion_std_random_init_float16(
    ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)


@torch.no_grad()
def test_diffusion_std_random_init_sag(
    sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
):
    sd15 = sd15_std
    n_steps = 30

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_num_inference_steps(n_steps)
    sd15.set_self_attention_guidance(enable=True, scale=0.75)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.decode_latents(x)

    ensure_similar_images(predicted_image, expected_image_std_random_init_sag)


@torch.no_grad()
def test_diffusion_std_init_image(
    sd15_std: StableDiffusion_1,
@@ -1364,6 +1403,42 @@ def test_sdxl_random_init(
    ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)


@torch.no_grad()
def test_sdxl_random_init_sag(
    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
) -> None:
    sdxl = sdxl_ddim
    expected_image = expected_sdxl_ddim_random_init_sag
    n_steps = 30

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_num_inference_steps(num_inference_steps=n_steps)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    manual_seed(seed=2)
    x = torch.randn(1, 4, 128, 128, device=test_device)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.decode_latents(x=x)

    ensure_similar_images(img_1=predicted_image, img_2=expected_image)


@torch.no_grad()
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
    manual_seed(seed=2)


@@ -34,6 +34,7 @@ output.images[0].save("std_random_init_expected.png")
Special cases:

- For self-attention guidance, `StableDiffusionSAGPipeline` has been used instead of the default pipeline (a sketch follows below).
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- The following references have been generated with refiners itself (and inspected so that they look reasonable):
  - `expected_inpainting_refonly.png`,
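Not part of the diff: a hedged sketch of how the SAG reference could be regenerated with diffusers' `StableDiffusionSAGPipeline`, as mentioned in the special case above. The checkpoint, seed, step count and scales below are assumptions, not necessarily the values actually used.

import torch
from diffusers import StableDiffusionSAGPipeline

pipe = StableDiffusionSAGPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
generator = torch.Generator(device="cuda").manual_seed(2)
output = pipe(
    prompt="a cute cat, detailed high-quality professional image",
    negative_prompt="lowres, bad anatomy, bad hands, cropped, worst quality",
    num_inference_steps=30,
    guidance_scale=7.5,
    sag_scale=0.75,
    generator=generator,
)
output.images[0].save("expected_std_random_init_sag.png")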
@@ -42,6 +43,7 @@ Special cases:
  - `expected_ip_adapter_controlnet.png`
  - `expected_t2i_adapter_xl_canny.png`
  - `expected_image_sdxl_ip_adapter_plus_woman.png`
  - `expected_cutecat_sdxl_ddim_random_init_sag.png`

## Other images

Binary file added (reference image, 1.6 MiB; not shown).

Binary file added (reference image, 493 KiB; not shown).