diff --git a/src/refiners/fluxion/adapters/adapter.py b/src/refiners/fluxion/adapters/adapter.py
index 5f55d40..f862dde 100644
--- a/src/refiners/fluxion/adapters/adapter.py
+++ b/src/refiners/fluxion/adapters/adapter.py
@@ -79,7 +79,7 @@ class Adapter(Generic[T]):
 
     def _pre_structural_copy(self) -> None:
         if isinstance(self.target, fl.Chain):
-            raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")
+            raise RuntimeError(f"Chain adapters ({self}) typically cannot be copied, eject them first.")
 
     def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
         self._target = [source.target]
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
index a4f003d..a4c312d 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -19,6 +19,7 @@ class SD1Autoencoder(LatentDiffusionAutoencoder):
 class StableDiffusion_1(LatentDiffusionModel):
     unet: SD1UNet
     clip_text_encoder: CLIPTextEncoderL
+    lda: SD1Autoencoder
 
     def __init__(
         self,
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
index 971cc1a..b272f0a 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
@@ -16,6 +16,7 @@ class SDXLAutoencoder(LatentDiffusionAutoencoder):
 class StableDiffusion_XL(LatentDiffusionModel):
     unet: SDXLUNet
     clip_text_encoder: DoubleTextEncoder
+    lda: SDXLAutoencoder
 
     def __init__(
        self,
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index a6426dc..76d3851 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -7,6 +7,7 @@ import pytest
 import torch
 from PIL import Image
 
+from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion import (
@@ -1640,6 +1641,49 @@ def test_sdxl_random_init_sag(
     ensure_similar_images(img_1=predicted_image, img_2=expected_image)
 
 
+@no_grad()
+def test_diffusion_sdxl_sliced_attention(
+    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image
+) -> None:
+    unet = sdxl_ddim.unet.structural_copy()
+    for layer in unet.layers(ScaledDotProductAttention):
+        layer.slice_size = 2048
+
+    sdxl = StableDiffusion_XL(
+        unet=unet,
+        lda=sdxl_ddim.lda,
+        clip_text_encoder=sdxl_ddim.clip_text_encoder,
+        solver=sdxl_ddim.solver,
+        device=sdxl_ddim.device,
+        dtype=sdxl_ddim.dtype,
+    )
+
+    expected_image = expected_sdxl_ddim_random_init
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text=prompt, negative_text=negative_prompt
+    )
+    time_ids = sdxl.default_time_ids
+    sdxl.set_inference_steps(30)
+    manual_seed(2)
+    x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
+    for step in sdxl.steps:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=5,
+        )
+
+    predicted_image = sdxl.lda.decode_latents(x)
+    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
+
+
 @no_grad()
 def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
     manual_seed(seed=2)
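
For context, a minimal sketch of the sliced-attention setup the new test exercises, and of the narrowed `lda` annotation this diff adds. It assumes StableDiffusion_XL can be default-constructed with pretrained components (as in the library's constructors); the device, dtype, and slice size here are illustrative, not prescribed.

import torch

from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.foundationals.latent_diffusion import StableDiffusion_XL

# Assumed: default-initialized components; load real weights in practice.
sdxl = StableDiffusion_XL(device="cuda", dtype=torch.float16)

# Copy the UNet so the original model stays intact, then cap each
# attention's slice size to bound peak memory at the cost of more,
# smaller matmuls.
unet = sdxl.unet.structural_copy()
for layer in unet.layers(ScaledDotProductAttention):
    layer.slice_size = 2048

# Rebuild the model around the sliced UNet. With the `lda` annotation
# added by this diff, `sliced.lda` is statically typed as
# SDXLAutoencoder rather than the base LatentDiffusionAutoencoder.
sliced = StableDiffusion_XL(
    unet=unet,
    lda=sdxl.lda,
    clip_text_encoder=sdxl.clip_text_encoder,
    solver=sdxl.solver,
    device=sdxl.device,
    dtype=sdxl.dtype,
)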