add a test for SDXL with sliced attention

This commit is contained in:
Pierre Chapuis 2024-01-30 16:49:30 +01:00 committed by Cédric Deltheil
parent 3ddd258d36
commit 5ac5373310
4 changed files with 47 additions and 1 deletion

View file

@@ -79,7 +79,7 @@ class Adapter(Generic[T]):
# NOTE(review): this span is a rendered diff hunk with the +/- markers stripped —
# the two `raise` lines below are the before/after versions of the SAME statement.
# After this commit only the f-string variant (which includes the adapter's repr
# for easier debugging) remains in the file.
def _pre_structural_copy(self) -> None:
# Structural copy is unsupported while this adapter wraps a Chain target:
# fail fast and tell the caller to eject the adapter first.
if isinstance(self.target, fl.Chain):
raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")
raise RuntimeError(f"Chain adapters ({self}) typically cannot be copied, eject them first.")
# After a structural copy, re-point the copy's target at the source's target so
# both adapters wrap the same underlying module. The target is stored inside a
# one-element list — presumably to keep the reference out of normal module
# child registration; TODO confirm against the `_target` declaration.
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
self._target = [source.target]

View file

@@ -19,6 +19,7 @@ class SD1Autoencoder(LatentDiffusionAutoencoder):
class StableDiffusion_1(LatentDiffusionModel):
unet: SD1UNet
clip_text_encoder: CLIPTextEncoderL
lda: SD1Autoencoder
def __init__(
self,

View file

@@ -16,6 +16,7 @@ class SDXLAutoencoder(LatentDiffusionAutoencoder):
class StableDiffusion_XL(LatentDiffusionModel):
unet: SDXLUNet
clip_text_encoder: DoubleTextEncoder
lda: SDXLAutoencoder
def __init__(
self,

View file

@@ -7,6 +7,7 @@ import pytest
import torch
from PIL import Image
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion import (
@@ -1640,6 +1641,49 @@ def test_sdxl_random_init_sag(
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
@no_grad()
def test_diffusion_sdxl_sliced_attention(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image
) -> None:
# End-to-end check that enabling sliced attention leaves the SDXL output
# (nearly) unchanged: the result is compared against the reference image
# for the plain random-init DDIM run — presumably produced without
# slicing; verify against the fixture that generates it.
# Copy the UNet so the shared sdxl_ddim fixture is left untouched, then
# enable slicing on every scaled-dot-product attention layer.
unet = sdxl_ddim.unet.structural_copy()
for layer in unet.layers(ScaledDotProductAttention):
layer.slice_size = 2048
# Rebuild a StableDiffusion_XL around the modified UNet, reusing the
# remaining components (autoencoder, text encoder, solver) and the
# fixture's device/dtype.
sdxl = StableDiffusion_XL(
unet=unet,
lda=sdxl_ddim.lda,
clip_text_encoder=sdxl_ddim.clip_text_encoder,
solver=sdxl_ddim.solver,
device=sdxl_ddim.device,
dtype=sdxl_ddim.dtype,
)
expected_image = expected_sdxl_ddim_random_init
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
# Joint embedding for prompt and negative prompt (SDXL also needs the
# pooled embedding and time ids as extra conditioning).
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
# Fixed seed so the initial latent matches the reference run exactly.
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
# Standard denoising loop with guidance strength condition_scale=5.
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.decode_latents(x)
# Sliced attention is numerically close but not necessarily bit-identical,
# so compare with PSNR/SSIM thresholds rather than exact equality.
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
manual_seed(seed=2)