add test for "Adapting SDXL" guide

This commit is contained in:
Pierre Chapuis 2024-03-06 10:56:46 +01:00
parent cd5fa97c20
commit 5d784bedab
9 changed files with 354 additions and 14 deletions

View file

@ -58,9 +58,9 @@ Then, define the inference parameters by setting the appropriate prompt / seed /
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
seed = 42
sdxl.set_inference_steps(50, first_step=0)
sdxl.set_self_attention_guidance(
enable=True, scale=0.75
) # Enable self-attention guidance to enhance the quality of the generated images
# Enable self-attention guidance to enhance the quality of the generated images
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
# ... Inference process
@ -76,10 +76,10 @@ with no_grad(): # Disable gradient calculation for memory-efficient inference
)
time_ids = sdxl.default_time_ids
manual_seed(seed=seed)
manual_seed(seed)
# Using a higher latents inner dim to improve resolution of generated images
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
# SDXL typically generates 1024x1024, here we use a higher resolution.
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
# Diffusion process
for step in sdxl.steps:
@ -131,8 +131,8 @@ predicted_image.save("vanilla_sdxl.png")
manual_seed(seed=seed)
# Using a higher latents inner dim to improve resolution of generated images
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
# SDXL typically generates 1024x1024, here we use a higher resolution.
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
# Diffusion process
for step in sdxl.steps:
@ -213,8 +213,8 @@ manager.add_loras("scifi-lora", tensors=scifi_lora_weights)
manual_seed(seed=seed)
# Using a higher latents inner dim to improve resolution of generated images
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
# SDXL typically generates 1024x1024, here we use a higher resolution.
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
# Diffusion process
for step in sdxl.steps:
@ -304,8 +304,8 @@ manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.saf
manual_seed(seed=seed)
# Using a higher latents inner dim to improve resolution of generated images
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
# SDXL typically generates 1024x1024, here we use a higher resolution.
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
# Diffusion process
for step in sdxl.steps:
@ -440,7 +440,7 @@ with torch.no_grad():
ip_adapter.set_clip_image_embedding(clip_image_embedding)
manual_seed(seed=seed)
x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype)
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
# Diffusion process
for step in sdxl.steps:
@ -578,7 +578,7 @@ with torch.no_grad():
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
manual_seed(seed=seed)
x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype)
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
# Diffusion process
for step in sdxl.steps:

View file

@ -253,6 +253,20 @@ def download_loras():
)
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder, expected_hash="ee170e4d")
dest_folder = os.path.join(test_weights_dir, "loras")
download_file(
"https://civitai.com/api/download/models/140624",
filename="Sci-fi_Environments_sdxl.safetensors",
dest_folder=dest_folder,
expected_hash="6a4afda8",
)
download_file(
"https://civitai.com/api/download/models/135931",
filename="pixel-art-xl-v1.1.safetensors",
dest_folder=dest_folder,
expected_hash="71aaa6ca",
)
def download_preprocessors():
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")

View file

@ -0,0 +1,321 @@
import gc
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.latent_diffusion import SDXLIPAdapter
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from tests.utils import ensure_similar_images
@pytest.fixture(autouse=True)
def ensure_gc():
# Avoid GPU OOMs
# See https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812
gc.collect()
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_doc_examples_ref"
@pytest.fixture(scope="module")
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def sdxl_unet_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "sdxl-unet.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
if not path.is_file():
warn(f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "CLIPImageEncoderH.safetensors"
if not path.is_file():
warn(f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture
def scifi_lora_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture
def pixelart_lora_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture
def sdxl(
sdxl_text_encoder_weights: Path,
sdxl_lda_fp16_fix_weights: Path,
sdxl_unet_weights: Path,
test_device: torch.device,
) -> StableDiffusion_XL:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
return sdxl
@pytest.fixture
def image_prompt_german_castle(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "german-castle.jpg").convert("RGB")
@pytest.fixture
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
@pytest.fixture
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
@pytest.fixture
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
@pytest.fixture
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
@no_grad()
def test_guide_adapting_sdxl_vanilla(
test_device: torch.device,
sdxl: StableDiffusion_XL,
expected_image_guide_adapting_sdxl_vanilla: Image.Image,
) -> None:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
expected_image = expected_image_guide_adapting_sdxl_vanilla
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
seed = 42
sdxl.set_inference_steps(50, first_step=0)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt + ", best quality, high quality",
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
)
time_ids = sdxl.default_time_ids
manual_seed(seed)
# The guide uses 2048x2048 but it is too slow for tests.
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
@no_grad()
def test_guide_adapting_sdxl_single_lora(
test_device: torch.device,
sdxl: StableDiffusion_XL,
scifi_lora_weights: Path,
expected_image_guide_adapting_sdxl_single_lora: Image.Image,
) -> None:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
expected_image = expected_image_guide_adapting_sdxl_single_lora
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
seed = 42
sdxl.set_inference_steps(50, first_step=0)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manager = SDLoraManager(sdxl)
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt + ", best quality, high quality",
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
)
time_ids = sdxl.default_time_ids
manual_seed(seed)
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
@no_grad()
def test_guide_adapting_sdxl_multiple_loras(
test_device: torch.device,
sdxl: StableDiffusion_XL,
scifi_lora_weights: Path,
pixelart_lora_weights: Path,
expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
) -> None:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
expected_image = expected_image_guide_adapting_sdxl_multiple_loras
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
seed = 42
sdxl.set_inference_steps(50, first_step=0)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manager = SDLoraManager(sdxl)
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt + ", best quality, high quality",
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
)
time_ids = sdxl.default_time_ids
manual_seed(seed)
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
@no_grad()
def test_guide_adapting_sdxl_loras_ip_adapter(
test_device: torch.device,
sdxl: StableDiffusion_XL,
sdxl_ip_adapter_plus_weights: Path,
image_encoder_weights: Path,
scifi_lora_weights: Path,
pixelart_lora_weights: Path,
image_prompt_german_castle: Image.Image,
expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
) -> None:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
expected_image = expected_image_guide_adapting_sdxl_loras_ip_adapter
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
seed = 42
sdxl.set_inference_steps(50, first_step=0)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manager = SDLoraManager(sdxl)
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5)
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55)
ip_adapter = SDXLIPAdapter(
target=sdxl.unet,
weights=load_from_safetensors(sdxl_ip_adapter_plus_weights),
scale=1.0,
fine_grained=True,
)
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt + ", best quality, high quality",
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
)
time_ids = sdxl.default_time_ids
image_prompt_preprocessed = ip_adapter.preprocess_image(image_prompt_german_castle)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(image_prompt_preprocessed)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
manual_seed(seed)
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
# We do not (yet) test the last example using T2i-Adapter with Zoe Depth.

View file

@ -0,0 +1,5 @@
# Note about this data
Everything in this directory comes from Refiners' documentation.
Some outputs are different because we perform inference in 1024x1024 and not 2048x2048.

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 MiB