diff --git a/docs/guides/adapting_sdxl/index.md b/docs/guides/adapting_sdxl/index.md index 5ac1a96..59a9873 100644 --- a/docs/guides/adapting_sdxl/index.md +++ b/docs/guides/adapting_sdxl/index.md @@ -58,9 +58,9 @@ Then, define the inference parameters by setting the appropriate prompt / seed / prompt = "a futuristic castle surrounded by a forest, mountains in the background" seed = 42 sdxl.set_inference_steps(50, first_step=0) -sdxl.set_self_attention_guidance( - enable=True, scale=0.75 -) # Enable self-attention guidance to enhance the quality of the generated images + +# Enable self-attention guidance to enhance the quality of the generated images +sdxl.set_self_attention_guidance(enable=True, scale=0.75) # ... Inference process @@ -76,10 +76,10 @@ with no_grad(): # Disable gradient calculation for memory-efficient inference ) time_ids = sdxl.default_time_ids - manual_seed(seed=seed) + manual_seed(seed) - # Using a higher latents inner dim to improve resolution of generated images - x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype) + # SDXL typically generates 1024x1024, here we use a higher resolution. + x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype) # Diffusion process for step in sdxl.steps: @@ -131,8 +131,8 @@ predicted_image.save("vanilla_sdxl.png") manual_seed(seed=seed) - # Using a higher latents inner dim to improve resolution of generated images - x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype) + # SDXL typically generates 1024x1024, here we use a higher resolution. 
+ x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype) # Diffusion process for step in sdxl.steps: @@ -213,8 +213,8 @@ manager.add_loras("scifi-lora", tensors=scifi_lora_weights) manual_seed(seed=seed) - # Using a higher latents inner dim to improve resolution of generated images - x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype) + # SDXL typically generates 1024x1024, here we use a higher resolution. + x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype) # Diffusion process for step in sdxl.steps: @@ -304,8 +304,8 @@ manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.saf manual_seed(seed=seed) - # Using a higher latents inner dim to improve resolution of generated images - x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype) + # SDXL typically generates 1024x1024, here we use a higher resolution. + x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype) # Diffusion process for step in sdxl.steps: @@ -440,7 +440,7 @@ with torch.no_grad(): ip_adapter.set_clip_image_embedding(clip_image_embedding) manual_seed(seed=seed) - x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype) + x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype) # Diffusion process for step in sdxl.steps: @@ -578,7 +578,7 @@ with torch.no_grad(): t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition)) manual_seed(seed=seed) - x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype) + x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype) # Diffusion process for step in sdxl.steps: diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index 8243d9d..58213ac 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -253,6 +253,20 @@ def download_loras(): ) download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder, 
def _weights_or_skip(path: Path) -> Path:
    """Return ``path`` when it is an existing file, otherwise warn and skip.

    Shared by all the weights fixtures of this module so that a missing
    weights file is always reported and handled in exactly the same way.

    Args:
        path: Expected location of a weights file.

    Returns:
        ``path`` unchanged, when the file exists.
    """
    if not path.is_file():
        # stacklevel=2 attributes the warning to the calling fixture, not to this helper.
        warn(message=f"could not find weights at {path}, skipping", stacklevel=2)
        pytest.skip(allow_module_level=True)
    return path


@pytest.fixture(scope="module")
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
    """Path to `DoubleCLIPTextEncoder.safetensors` (skips when the file is missing)."""
    return _weights_or_skip(test_weights_path / "DoubleCLIPTextEncoder.safetensors")


@pytest.fixture(scope="module")
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
    """Path to `sdxl-lda-fp16-fix.safetensors` (skips when the file is missing)."""
    return _weights_or_skip(test_weights_path / "sdxl-lda-fp16-fix.safetensors")


@pytest.fixture(scope="module")
def sdxl_unet_weights(test_weights_path: Path) -> Path:
    """Path to `sdxl-unet.safetensors` (skips when the file is missing)."""
    return _weights_or_skip(test_weights_path / "sdxl-unet.safetensors")


@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
    """Path to `ip-adapter-plus_sdxl_vit-h.safetensors` (skips when the file is missing)."""
    return _weights_or_skip(test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors")


@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
    """Path to `CLIPImageEncoderH.safetensors` (skips when the file is missing)."""
    return _weights_or_skip(test_weights_path / "CLIPImageEncoderH.safetensors")


# The two LoRA fixtures are module-scoped like the other weights fixtures:
# they only resolve a path, so there is nothing to recompute per test.
@pytest.fixture(scope="module")
def scifi_lora_weights(test_weights_path: Path) -> Path:
    """Path to `loras/Sci-fi_Environments_sdxl.safetensors` (skips when the file is missing)."""
    return _weights_or_skip(test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors")


@pytest.fixture(scope="module")
def pixelart_lora_weights(test_weights_path: Path) -> Path:
    """Path to `loras/pixel-art-xl-v1.1.safetensors` (skips when the file is missing)."""
    return _weights_or_skip(test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors")
@pytest.fixture
def image_prompt_german_castle(ref_path: Path) -> Image.Image:
    """Open `german-castle.jpg` from the reference directory, converted to RGB."""
    source = ref_path / "german-castle.jpg"
    picture = Image.open(source)
    return picture.convert("RGB")


@pytest.fixture
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
    """Open the reference image of the vanilla test, converted to RGB."""
    source = ref_path / "expected_image_guide_adapting_sdxl_vanilla.png"
    picture = Image.open(source)
    return picture.convert("RGB")


@pytest.fixture
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
    """Open the reference image of the single LoRA test, converted to RGB."""
    source = ref_path / "expected_image_guide_adapting_sdxl_single_lora.png"
    picture = Image.open(source)
    return picture.convert("RGB")


@pytest.fixture
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
    """Open the reference image of the multiple LoRAs test, converted to RGB."""
    source = ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png"
    picture = Image.open(source)
    return picture.convert("RGB")


@pytest.fixture
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
    """Open the reference image of the LoRAs + IP-Adapter test, converted to RGB."""
    source = ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png"
    picture = Image.open(source)
    return picture.convert("RGB")
@no_grad()
def test_guide_adapting_sdxl_single_lora(
    test_device: torch.device,
    sdxl: StableDiffusion_XL,
    scifi_lora_weights: Path,
    expected_image_guide_adapting_sdxl_single_lora: Image.Image,
) -> None:
    """Generate a 1024x1024 image with SDXL plus the sci-fi LoRA and compare it to the reference."""
    on_cpu = test_device.type == "cpu"
    if on_cpu:
        warn(message="not running on CPU, skipping")
        pytest.skip()

    base_prompt = "a futuristic castle surrounded by a forest, mountains in the background"

    # Solver and guidance settings.
    sdxl.set_inference_steps(50, first_step=0)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    # Register the LoRA on the model.
    lora_manager = SDLoraManager(sdxl)
    scifi_tensors = load_from_safetensors(scifi_lora_weights)
    lora_manager.add_loras("scifi-lora", scifi_tensors)

    # Text conditioning (positive and negative prompts).
    text_embedding, pooled_embedding = sdxl.compute_clip_text_embedding(
        text=base_prompt + ", best quality, high quality",
        negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    default_time_ids = sdxl.default_time_ids

    # Seed right before sampling the initial latents.
    manual_seed(42)
    latents = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
    for current_step in sdxl.steps:
        latents = sdxl(
            latents,
            step=current_step,
            clip_text_embedding=text_embedding,
            pooled_text_embedding=pooled_embedding,
            time_ids=default_time_ids,
        )

    generated = sdxl.lda.decode_latents(latents)
    ensure_similar_images(generated, expected_image_guide_adapting_sdxl_single_lora)
@no_grad()
def test_guide_adapting_sdxl_loras_ip_adapter(
    test_device: torch.device,
    sdxl: StableDiffusion_XL,
    sdxl_ip_adapter_plus_weights: Path,
    image_encoder_weights: Path,
    scifi_lora_weights: Path,
    pixelart_lora_weights: Path,
    image_prompt_german_castle: Image.Image,
    expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
) -> None:
    """Generate a 1024x1024 image with SDXL, two LoRAs and an IP-Adapter, then compare it to the reference."""
    on_cpu = test_device.type == "cpu"
    if on_cpu:
        warn(message="not running on CPU, skipping")
        pytest.skip()

    base_prompt = "a futuristic castle surrounded by a forest, mountains in the background"

    # Solver and guidance settings.
    sdxl.set_inference_steps(50, first_step=0)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    # Register both LoRAs, each with its own scale.
    lora_manager = SDLoraManager(sdxl)
    scifi_tensors = load_from_safetensors(scifi_lora_weights)
    lora_manager.add_loras("scifi-lora", scifi_tensors, scale=1.5)
    pixelart_tensors = load_from_safetensors(pixelart_lora_weights)
    lora_manager.add_loras("pixel-art-lora", pixelart_tensors, scale=1.55)

    # Build the IP-Adapter on the UNet, load its image encoder, then inject it.
    adapter_weights = load_from_safetensors(sdxl_ip_adapter_plus_weights)
    adapter = SDXLIPAdapter(
        target=sdxl.unet,
        weights=adapter_weights,
        scale=1.0,
        fine_grained=True,
    )
    adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
    adapter.inject()

    # Text conditioning (positive and negative prompts).
    text_embedding, pooled_embedding = sdxl.compute_clip_text_embedding(
        text=base_prompt + ", best quality, high quality",
        negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    default_time_ids = sdxl.default_time_ids

    # Image conditioning from the image prompt.
    preprocessed_prompt = adapter.preprocess_image(image_prompt_german_castle)
    image_embedding = adapter.compute_clip_image_embedding(preprocessed_prompt)
    adapter.set_clip_image_embedding(image_embedding)

    # Seed right before sampling the initial latents.
    manual_seed(42)
    latents = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
    for current_step in sdxl.steps:
        latents = sdxl(
            latents,
            step=current_step,
            clip_text_embedding=text_embedding,
            pooled_text_embedding=pooled_embedding,
            time_ids=default_time_ids,
        )

    generated = sdxl.lda.decode_latents(latents)
    ensure_similar_images(generated, expected_image_guide_adapting_sdxl_loras_ip_adapter)
diff --git a/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_loras_ip_adapter.png b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_loras_ip_adapter.png new file mode 100644 index 0000000..2771ea7 Binary files /dev/null and b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_loras_ip_adapter.png differ diff --git a/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_multiple_loras.png b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_multiple_loras.png new file mode 100644 index 0000000..7cda803 Binary files /dev/null and b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_multiple_loras.png differ diff --git a/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_single_lora.png b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_single_lora.png new file mode 100644 index 0000000..1c76375 Binary files /dev/null and b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_single_lora.png differ diff --git a/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_vanilla.png b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_vanilla.png new file mode 100644 index 0000000..30460dc Binary files /dev/null and b/tests/e2e/test_doc_examples_ref/expected_image_guide_adapting_sdxl_vanilla.png differ diff --git a/tests/e2e/test_doc_examples_ref/german-castle.jpg b/tests/e2e/test_doc_examples_ref/german-castle.jpg new file mode 100644 index 0000000..ec9015b Binary files /dev/null and b/tests/e2e/test_doc_examples_ref/german-castle.jpg differ