mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add a test for SDXL with sliced attention
This commit is contained in:
parent
3ddd258d36
commit
5ac5373310
|
@ -79,7 +79,7 @@ class Adapter(Generic[T]):
|
|||
|
||||
def _pre_structural_copy(self) -> None:
    """Reject the operation when this adapter's target is a Chain.

    NOTE(review): presumably invoked as a hook right before a structural
    copy -- confirm against the caller.

    Raises:
        RuntimeError: if `self.target` is an `fl.Chain`.
    """
    if isinstance(self.target, fl.Chain):
        # Include the adapter itself in the message so the offending
        # adapter can be identified and ejected.
        raise RuntimeError(f"Chain adapters ({self}) typically cannot be copied, eject them first.")
|
||||
|
||||
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
    """Re-point this adapter at the target of `source`.

    The target is stored in `_target` wrapped in a single-element list.
    """
    source_target = source.target
    self._target = [source_target]
|
||||
|
|
|
@ -19,6 +19,7 @@ class SD1Autoencoder(LatentDiffusionAutoencoder):
|
|||
class StableDiffusion_1(LatentDiffusionModel):
|
||||
unet: SD1UNet
|
||||
clip_text_encoder: CLIPTextEncoderL
|
||||
lda: SD1Autoencoder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -16,6 +16,7 @@ class SDXLAutoencoder(LatentDiffusionAutoencoder):
|
|||
class StableDiffusion_XL(LatentDiffusionModel):
|
||||
unet: SDXLUNet
|
||||
clip_text_encoder: DoubleTextEncoder
|
||||
lda: SDXLAutoencoder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -7,6 +7,7 @@ import pytest
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.latent_diffusion import (
|
||||
|
@ -1640,6 +1641,49 @@ def test_sdxl_random_init_sag(
|
|||
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
|
||||
|
||||
|
||||
@no_grad()
def test_diffusion_sdxl_sliced_attention(
    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image
) -> None:
    """Check SDXL output when attention is computed with a slice size.

    A structural copy of the fixture's UNet gets `slice_size = 2048` on every
    `ScaledDotProductAttention` layer, then a 30-step generation is compared
    against the `expected_sdxl_ddim_random_init` reference image.
    """
    # Work on a structural copy so the slice size is set on the copy's
    # attention layers rather than through the fixture's own UNet object.
    sliced_unet = sdxl_ddim.unet.structural_copy()
    for attention in sliced_unet.layers(ScaledDotProductAttention):
        attention.slice_size = 2048

    # Rebuild the pipeline around the sliced UNet, reusing every other part.
    sdxl = StableDiffusion_XL(
        unet=sliced_unet,
        lda=sdxl_ddim.lda,
        clip_text_encoder=sdxl_ddim.clip_text_encoder,
        solver=sdxl_ddim.solver,
        device=sdxl_ddim.device,
        dtype=sdxl_ddim.dtype,
    )

    text_embedding, pooled_embedding = sdxl.compute_clip_text_embedding(
        text="a cute cat, detailed high-quality professional image",
        negative_text="lowres, bad anatomy, bad hands, cropped, worst quality",
    )
    time_ids = sdxl.default_time_ids
    sdxl.set_inference_steps(30)

    # Seed right before sampling the initial latents.
    manual_seed(2)
    latents = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)

    for step in sdxl.steps:
        latents = sdxl(
            latents,
            step=step,
            clip_text_embedding=text_embedding,
            pooled_text_embedding=pooled_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )

    predicted_image = sdxl.lda.decode_latents(latents)
    ensure_similar_images(
        predicted_image,
        expected_sdxl_ddim_random_init,
        min_psnr=35,
        min_ssim=0.98,
    )
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
|
||||
manual_seed(seed=2)
|
||||
|
|
Loading…
Reference in a new issue