mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00

add support for self-attention guidance

See https://arxiv.org/abs/2210.00939

This commit is contained in: parent 976b55aea5, commit d3365d6383
@@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler


 T = TypeVar("T", bound="fl.Module")

@@ -68,6 +67,17 @@ class LatentDiffusionModel(fl.Module, ABC):
     @abstractmethod
     def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...

+    @abstractmethod
+    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: ...
+
+    @abstractmethod
+    def has_self_attention_guidance(self) -> bool: ...
+
+    @abstractmethod
+    def compute_self_attention_guidance(
+        self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
+    ) -> Tensor: ...
+
     def forward(
         self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
     ) -> Tensor:
@@ -80,6 +90,12 @@ class LatentDiffusionModel(fl.Module, ABC):
         # classifier-free guidance
         noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
         x = x.narrow(dim=1, start=0, length=4)  # support > 4 channels for inpainting
+
+        if self.has_self_attention_guidance():
+            noise += self.compute_self_attention_guidance(
+                x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
+            )
+
         return self.scheduler(x, noise=noise, step=step)

     def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
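The hunk above combines classifier-free guidance with the SAG correction from the paper. As a plain-tensor sketch of that arithmetic (toy shapes; `s_cfg` and `s_sag` are illustrative names for `condition_scale` and the adapter's `scale`, not identifiers from this commit):

```python
import torch

# `uncond` / `cond` stand for the unconditional and conditional UNet predictions,
# `degraded` for the prediction on blurred latents.
uncond = torch.randn(1, 4, 64, 64)
cond = torch.randn(1, 4, 64, 64)
degraded = torch.randn(1, 4, 64, 64)
s_cfg, s_sag = 7.5, 0.75

noise = uncond + s_cfg * (cond - uncond)     # classifier-free guidance
noise = noise + s_sag * (uncond - degraded)  # SAG term: compute_self_attention_guidance
# returns s_sag * (noise_uncond - degraded_noise), which forward adds to `noise`
```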
@@ -0,0 +1,101 @@
+from typing import Any, Generic, TypeVar, TYPE_CHECKING
+import math
+
+from torch import Tensor, Size
+from jaxtyping import Float
+import torch
+
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.context import Contexts
+from refiners.fluxion.utils import interpolate, gaussian_blur
+import refiners.fluxion.layers as fl
+
+if TYPE_CHECKING:
+    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+
+T = TypeVar("T", bound="SD1UNet | SDXLUNet")
+TSAGAdapter = TypeVar("TSAGAdapter", bound="SAGAdapter[Any]")  # Self (see PEP 673)
+
+
+class SelfAttentionMap(fl.Passthrough):
+    def __init__(self, num_heads: int, context_key: str) -> None:
+        self.num_heads = num_heads
+        self.context_key = context_key
+        super().__init__(
+            fl.Lambda(func=self.compute_attention_scores),
+            fl.SetContext(context="self_attention_map", key=context_key),
+        )
+
+    def split_to_multi_head(
+        self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
+    ) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
+        assert (
+            len(x.shape) == 3
+        ), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
+        assert (
+            x.shape[-1] % self.num_heads == 0
+        ), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
+        return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
+
+    def compute_attention_scores(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
+        query, key = self.split_to_multi_head(query), self.split_to_multi_head(key)
+        _, _, _, dim = query.shape
+        attention = query @ key.permute(0, 1, 3, 2)
+        attention = attention / math.sqrt(dim)
+        return torch.softmax(input=attention, dim=-1)
+
+
+class SelfAttentionShape(fl.Passthrough):
+    def __init__(self, context_key: str) -> None:
+        self.context_key = context_key
+        super().__init__(
+            fl.SetContext(context="self_attention_map", key=context_key, callback=self.register_shape),
+        )
+
+    def register_shape(self, shapes: list[Size], x: Tensor) -> None:
+        assert x.ndim == 4, f"Expected 4D tensor, got {x.ndim}D with shape {x.shape}"
+        shapes.append(x.shape[-2:])
+
+
+class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
+    def __init__(self, target: T, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
+        self.scale = scale
+        self.kernel_size = kernel_size
+        self.sigma = sigma
+        with self.setup_adapter(target):
+            super().__init__(target)
+
+    def inject(self: "TSAGAdapter", parent: fl.Chain | None = None) -> "TSAGAdapter":
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        super().eject()
+
+    def compute_sag_mask(
+        self, latents: Float[Tensor, "batch_size channels height width"], classifier_free_guidance: bool = True
+    ) -> Float[Tensor, "batch_size channels height width"]:
+        attn_map = self.use_context("self_attention_map")["middle_block_attn_map"]
+        if classifier_free_guidance:
+            unconditional_attn, _ = attn_map.chunk(2)
+            attn_map = unconditional_attn
+        attn_shape = self.use_context("self_attention_map")["middle_block_attn_shape"].pop()
+        assert len(attn_shape) == 2
+        b, c, h, w = latents.shape
+        attn_h, attn_w = attn_shape
+        attn_mask = attn_map.mean(dim=1, keepdim=False).sum(dim=1, keepdim=False) > 1.0
+        attn_mask = attn_mask.reshape(b, attn_h, attn_w).unsqueeze(1).repeat(1, c, 1, 1).type(attn_map.dtype)
+        return interpolate(attn_mask, Size((h, w)))
+
+    def compute_degraded_latents(
+        self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
+    ) -> Tensor:
+        sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
+        original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
+        degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
+        degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
+        return scheduler.add_noise(degraded_latents, noise=noise, step=step)
+
+    def init_context(self) -> Contexts:
+        return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}
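To see what `compute_sag_mask` and `compute_degraded_latents` do without the Fluxion machinery, here is a minimal sketch with toy shapes. torchvision's `gaussian_blur` stands in for `refiners.fluxion.utils.gaussian_blur`, and the attention map is random; only the masking and blending logic mirrors the adapter above:

```python
import torch
from torchvision.transforms.functional import gaussian_blur

b, c, h, w = 1, 4, 64, 64
latents = torch.randn(b, c, h, w)
# (2b, heads, hw, hw): CFG doubles the batch; rows sum to 1 after softmax
attn_map = torch.rand(b * 2, 8, 64, 64).softmax(dim=-1)

uncond_attn, _ = attn_map.chunk(2)               # keep the unconditional half
mask = uncond_attn.mean(dim=1).sum(dim=1) > 1.0  # positions attended more than average
mask = mask.reshape(b, 8, 8).unsqueeze(1).repeat(1, c, 1, 1).float()
mask = torch.nn.functional.interpolate(mask, size=(h, w))

blurred = gaussian_blur(latents, kernel_size=[9, 9], sigma=[1.0, 1.0])
degraded = blurred * mask + latents * (1 - mask)  # blur only the salient region
```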
@@ -6,6 +6,7 @@ from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
 from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
 from PIL import Image
 import numpy as np
 from torch import device as Device, dtype as DType, Tensor
@@ -54,6 +55,47 @@ class StableDiffusion_1(LatentDiffusionModel):
         self.unet.set_timestep(timestep=timestep)
         self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

+    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
+        if enable:
+            if sag := self._find_sag_adapter():
+                sag.scale = scale
+            else:
+                sag = SD1SAGAdapter(target=self.unet, scale=scale)
+                sag.inject()
+        else:
+            if sag := self._find_sag_adapter():
+                sag.eject()
+
+    def has_self_attention_guidance(self) -> bool:
+        return self._find_sag_adapter() is not None
+
+    def _find_sag_adapter(self) -> SD1SAGAdapter | None:
+        for p in self.unet.get_parents():
+            if isinstance(p, SD1SAGAdapter):
+                return p
+        return None
+
+    def compute_self_attention_guidance(
+        self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
+    ) -> Tensor:
+        sag = self._find_sag_adapter()
+        assert sag is not None
+
+        degraded_latents = sag.compute_degraded_latents(
+            scheduler=self.scheduler,
+            latents=x,
+            noise=noise,
+            step=step,
+            classifier_free_guidance=True,
+        )
+
+        negative_embedding, _ = clip_text_embedding.chunk(2)
+        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
+        degraded_noise = self.unet(degraded_latents)
+
+        return sag.scale * (noise - degraded_noise)
+
+
 class StableDiffusion_1_Inpainting(StableDiffusion_1):
     def __init__(
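With the hunk above in place, SAG becomes a one-line toggle on the model. A usage sketch, mirroring the e2e test added later in this commit (construction details and weight loading are assumptions here, not part of the diff):

```python
import torch
from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd15 = StableDiffusion_1(device="cuda")  # assumption: default sub-models; load weights separately
sd15.set_num_inference_steps(30)
sd15.set_self_attention_guidance(enable=True, scale=0.75)  # injects the SAG adapter

clip_text_embedding = sd15.compute_clip_text_embedding(
    text="a cute cat", negative_text="lowres, bad anatomy"
)

x = torch.randn(1, 4, 64, 64, device="cuda")
with torch.no_grad():
    for step in sd15.steps:
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
predicted_image = sd15.lda.decode_latents(x)

sd15.set_self_attention_guidance(enable=False)  # ejects the adapter again
```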
@@ -0,0 +1,41 @@
+from refiners.foundationals.latent_diffusion.self_attention_guidance import (
+    SAGAdapter,
+    SelfAttentionShape,
+    SelfAttentionMap,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, MiddleBlock, ResidualBlock
+from refiners.fluxion.layers.attentions import ScaledDotProductAttention
+import refiners.fluxion.layers as fl
+
+
+class SD1SAGAdapter(SAGAdapter[SD1UNet]):
+    def __init__(self, target: SD1UNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
+        super().__init__(
+            target=target,
+            scale=scale,
+            kernel_size=kernel_size,
+            sigma=sigma,
+        )
+
+    def inject(self: "SD1SAGAdapter", parent: fl.Chain | None = None) -> "SD1SAGAdapter":
+        middle_block = self.target.ensure_find(MiddleBlock)
+        middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))
+
+        # An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
+        # scores to avoid computing these scores twice
+        self_attn = middle_block.ensure_find(fl.SelfAttention)
+        self_attn.insert_before_type(
+            ScaledDotProductAttention,
+            SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
+        )
+
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        middle_block = self.target.ensure_find(MiddleBlock)
+        middle_block.remove(middle_block.ensure_find(SelfAttentionShape))
+
+        self_attn = middle_block.ensure_find(fl.SelfAttention)
+        self_attn.remove(self_attn.ensure_find(SelfAttentionMap))
+
+        super().eject()
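The adapter's `inject`/`eject` pair is fully reversible: it hooks a shape recorder and an attention probe into the UNet's middle block, and removes them again. A sketch of that lifecycle, using only the names from the diff above (the bare `SD1UNet(in_channels=4)` construction is an assumption for illustration):

```python
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
    SelfAttentionMap,
    SelfAttentionShape,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, MiddleBlock

unet = SD1UNet(in_channels=4)
adapter = SD1SAGAdapter(target=unet).inject()

middle_block = unet.ensure_find(MiddleBlock)
middle_block.ensure_find(SelfAttentionShape)  # raises if the shape recorder is missing
middle_block.ensure_find(SelfAttentionMap)    # raises if the attention probe is missing

adapter.eject()  # removes both probes and detaches the adapter
```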
@@ -4,6 +4,7 @@ from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
 from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
 from torch import device as Device, dtype as DType, Tensor
@@ -67,7 +68,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
         clip_text_embedding: Tensor,
         pooled_text_embedding: Tensor,
         time_ids: Tensor,
-        **_: Tensor
+        **_: Tensor,
     ) -> None:
         self.unet.set_timestep(timestep=timestep)
         self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
@@ -83,7 +84,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
         pooled_text_embedding: Tensor,
         time_ids: Tensor,
         condition_scale: float = 5.0,
-        **kwargs: Tensor
+        **kwargs: Tensor,
     ) -> Tensor:
         return super().forward(
             x=x,
@@ -92,5 +93,62 @@ class StableDiffusion_XL(LatentDiffusionModel):
             pooled_text_embedding=pooled_text_embedding,
             time_ids=time_ids,
             condition_scale=condition_scale,
-            **kwargs
+            **kwargs,
         )

+    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
+        if enable:
+            if sag := self._find_sag_adapter():
+                sag.scale = scale
+            else:
+                sag = SDXLSAGAdapter(target=self.unet, scale=scale)
+                sag.inject()
+        else:
+            if sag := self._find_sag_adapter():
+                sag.eject()
+
+    def has_self_attention_guidance(self) -> bool:
+        return self._find_sag_adapter() is not None
+
+    def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
+        for p in self.unet.get_parents():
+            if isinstance(p, SDXLSAGAdapter):
+                return p
+        return None
+
+    def compute_self_attention_guidance(
+        self,
+        x: Tensor,
+        noise: Tensor,
+        step: int,
+        *,
+        clip_text_embedding: Tensor,
+        pooled_text_embedding: Tensor,
+        time_ids: Tensor,
+        **kwargs: Tensor,
+    ) -> Tensor:
+        sag = self._find_sag_adapter()
+        assert sag is not None
+
+        degraded_latents = sag.compute_degraded_latents(
+            scheduler=self.scheduler,
+            latents=x,
+            noise=noise,
+            step=step,
+            classifier_free_guidance=True,
+        )
+
+        negative_embedding, _ = clip_text_embedding.chunk(2)
+        negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
+        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        time_ids, _ = time_ids.chunk(2)
+        self.set_unet_context(
+            timestep=timestep,
+            clip_text_embedding=negative_embedding,
+            pooled_text_embedding=negative_pooled_embedding,
+            time_ids=time_ids,
+            **kwargs,
+        )
+        degraded_noise = self.unet(degraded_latents)
+
+        return sag.scale * (noise - degraded_noise)
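The SDXL path is the same pattern as SD 1.5, except that the pooled text embedding and `time_ids` are also chunked to their unconditional halves before the degraded pass. A usage sketch, mirroring the SDXL e2e test added below (construction and weight loading are assumptions):

```python
import torch
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL

sdxl = StableDiffusion_XL(device="cuda")  # assumption: default sub-models; load weights separately
sdxl.set_num_inference_steps(num_inference_steps=30)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)

clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
    text="a cute cat", negative_text="lowres, bad anatomy"
)
time_ids = sdxl.default_time_ids

x = torch.randn(1, 4, 128, 128, device="cuda")
with torch.no_grad():
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5.0,
        )
```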
@@ -0,0 +1,41 @@
+from refiners.foundationals.latent_diffusion.self_attention_guidance import (
+    SAGAdapter,
+    SelfAttentionShape,
+    SelfAttentionMap,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet, MiddleBlock, ResidualBlock
+from refiners.fluxion.layers.attentions import ScaledDotProductAttention
+import refiners.fluxion.layers as fl
+
+
+class SDXLSAGAdapter(SAGAdapter[SDXLUNet]):
+    def __init__(self, target: SDXLUNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
+        super().__init__(
+            target=target,
+            scale=scale,
+            kernel_size=kernel_size,
+            sigma=sigma,
+        )
+
+    def inject(self: "SDXLSAGAdapter", parent: fl.Chain | None = None) -> "SDXLSAGAdapter":
+        middle_block = self.target.ensure_find(MiddleBlock)
+        middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))
+
+        # An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
+        # scores to avoid computing these scores twice
+        self_attn = middle_block.ensure_find(fl.SelfAttention)
+        self_attn.insert_before_type(
+            ScaledDotProductAttention,
+            SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
+        )
+
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        middle_block = self.target.ensure_find(MiddleBlock)
+        middle_block.remove(middle_block.ensure_find(SelfAttentionShape))
+
+        self_attn = middle_block.ensure_find(fl.SelfAttention)
+        self_attn.remove(self_attn.ensure_find(SelfAttentionMap))
+
+        super().eject()
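Both adapters reuse `SelfAttentionMap`, whose `compute_attention_scores` is ordinary multi-head softmax(QKᵀ/√d_head), stopping before the value product. A standalone sanity check of that arithmetic (toy shapes only; `split` mirrors `split_to_multi_head`):

```python
import math
import torch

b, n, d, h = 2, 16, 64, 8  # batch, sequence length, embedding dim, heads
q, k = torch.randn(b, n, d), torch.randn(b, n, d)

def split(x: torch.Tensor) -> torch.Tensor:
    # (b, n, d) -> (b, h, n, d // h), as in split_to_multi_head
    return x.reshape(b, n, h, d // h).transpose(1, 2)

scores = torch.softmax(split(q) @ split(k).transpose(-2, -1) / math.sqrt(d // h), dim=-1)
assert scores.shape == (b, h, n, n)
# each row is a probability distribution over the n key positions
assert torch.allclose(scores.sum(dim=-1), torch.ones(b, h, n))
```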
@@ -64,6 +64,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")


+@pytest.fixture
+def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_std_init_image(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
@@ -109,6 +114,11 @@ def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
     return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")


+@pytest.fixture
+def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
+    return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert(mode="RGB")
+
+
 @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
 def controlnet_data(
     ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@@ -514,6 +524,35 @@ def test_diffusion_std_random_init_float16(
     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)


+@torch.no_grad()
+def test_diffusion_std_random_init_sag(
+    sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
+):
+    sd15 = sd15_std
+    n_steps = 30
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_num_inference_steps(n_steps)
+    sd15.set_self_attention_guidance(enable=True, scale=0.75)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
+
+
 @torch.no_grad()
 def test_diffusion_std_init_image(
     sd15_std: StableDiffusion_1,
@@ -1364,6 +1403,42 @@ def test_sdxl_random_init(
     ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)


+@torch.no_grad()
+def test_sdxl_random_init_sag(
+    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
+) -> None:
+    sdxl = sdxl_ddim
+    expected_image = expected_sdxl_ddim_random_init_sag
+    n_steps = 30
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text=prompt, negative_text=negative_prompt
+    )
+    time_ids = sdxl.default_time_ids
+
+    sdxl.set_num_inference_steps(num_inference_steps=n_steps)
+    sdxl.set_self_attention_guidance(enable=True, scale=0.75)
+
+    manual_seed(seed=2)
+    x = torch.randn(1, 4, 128, 128, device=test_device)
+
+    for step in sdxl.steps:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=5,
+        )
+    predicted_image = sdxl.lda.decode_latents(x=x)
+
+    ensure_similar_images(img_1=predicted_image, img_2=expected_image)
+
+
 @torch.no_grad()
 def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
     manual_seed(seed=2)
@@ -34,6 +34,7 @@ output.images[0].save("std_random_init_expected.png")

 Special cases:

+- For self-attention guidance, `StableDiffusionSAGPipeline` has been used instead of the default pipeline.
 - `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
 - The following references have been generated with refiners itself (and inspected so that they look reasonable):
   - `expected_inpainting_refonly.png`,
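For context, a sketch of how such a SAG reference could be produced with diffusers' `StableDiffusionSAGPipeline` (the model id, prompt, and scales are illustrative assumptions, not taken from this commit):

```python
import torch
from diffusers import StableDiffusionSAGPipeline

# Illustrative only: ids and parameters are assumptions.
pipe = StableDiffusionSAGPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
image = pipe(
    "a cute cat, detailed high-quality professional image",
    guidance_scale=7.5,
    sag_scale=0.75,
).images[0]
image.save("expected_std_random_init_sag.png")
```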
@@ -42,6 +43,7 @@ Special cases:
   - `expected_ip_adapter_controlnet.png`
   - `expected_t2i_adapter_xl_canny.png`
   - `expected_image_sdxl_ip_adapter_plus_woman.png`
+  - `expected_cutecat_sdxl_ddim_random_init_sag.png`

 ## Other images
BIN tests/e2e/test_diffusion_ref/expected_cutecat_sdxl_ddim_random_init_sag.png (new file, binary not shown; size: 1.6 MiB)
BIN tests/e2e/test_diffusion_ref/expected_std_random_init_sag.png (new file, binary not shown; size: 493 KiB)