mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add a test for SDXL with sliced attention
This commit is contained in:
parent
3ddd258d36
commit
5ac5373310
|
@ -79,7 +79,7 @@ class Adapter(Generic[T]):
|
||||||
|
|
||||||
def _pre_structural_copy(self) -> None:
|
def _pre_structural_copy(self) -> None:
|
||||||
if isinstance(self.target, fl.Chain):
|
if isinstance(self.target, fl.Chain):
|
||||||
raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")
|
raise RuntimeError(f"Chain adapters ({self}) typically cannot be copied, eject them first.")
|
||||||
|
|
||||||
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
|
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
|
||||||
self._target = [source.target]
|
self._target = [source.target]
|
||||||
|
|
|
@ -19,6 +19,7 @@ class SD1Autoencoder(LatentDiffusionAutoencoder):
|
||||||
class StableDiffusion_1(LatentDiffusionModel):
|
class StableDiffusion_1(LatentDiffusionModel):
|
||||||
unet: SD1UNet
|
unet: SD1UNet
|
||||||
clip_text_encoder: CLIPTextEncoderL
|
clip_text_encoder: CLIPTextEncoderL
|
||||||
|
lda: SD1Autoencoder
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -16,6 +16,7 @@ class SDXLAutoencoder(LatentDiffusionAutoencoder):
|
||||||
class StableDiffusion_XL(LatentDiffusionModel):
|
class StableDiffusion_XL(LatentDiffusionModel):
|
||||||
unet: SDXLUNet
|
unet: SDXLUNet
|
||||||
clip_text_encoder: DoubleTextEncoder
|
clip_text_encoder: DoubleTextEncoder
|
||||||
|
lda: SDXLAutoencoder
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -7,6 +7,7 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
from refiners.foundationals.latent_diffusion import (
|
from refiners.foundationals.latent_diffusion import (
|
||||||
|
@ -1640,6 +1641,49 @@ def test_sdxl_random_init_sag(
|
||||||
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
|
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_diffusion_sdxl_sliced_attention(
|
||||||
|
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image
|
||||||
|
) -> None:
|
||||||
|
unet = sdxl_ddim.unet.structural_copy()
|
||||||
|
for layer in unet.layers(ScaledDotProductAttention):
|
||||||
|
layer.slice_size = 2048
|
||||||
|
|
||||||
|
sdxl = StableDiffusion_XL(
|
||||||
|
unet=unet,
|
||||||
|
lda=sdxl_ddim.lda,
|
||||||
|
clip_text_encoder=sdxl_ddim.clip_text_encoder,
|
||||||
|
solver=sdxl_ddim.solver,
|
||||||
|
device=sdxl_ddim.device,
|
||||||
|
dtype=sdxl_ddim.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_image = expected_sdxl_ddim_random_init
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
|
text=prompt, negative_text=negative_prompt
|
||||||
|
)
|
||||||
|
time_ids = sdxl.default_time_ids
|
||||||
|
sdxl.set_inference_steps(30)
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||||
|
for step in sdxl.steps:
|
||||||
|
x = sdxl(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
pooled_text_embedding=pooled_text_embedding,
|
||||||
|
time_ids=time_ids,
|
||||||
|
condition_scale=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
predicted_image = sdxl.lda.decode_latents(x)
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
|
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
|
||||||
manual_seed(seed=2)
|
manual_seed(seed=2)
|
||||||
|
|
Loading…
Reference in a new issue