mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add test for "Adapting SDXL" guide
This commit is contained in:
parent
cd5fa97c20
commit
5d784bedab
|
@ -58,9 +58,9 @@ Then, define the inference parameters by setting the appropriate prompt / seed /
|
|||
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
|
||||
seed = 42
|
||||
sdxl.set_inference_steps(50, first_step=0)
|
||||
sdxl.set_self_attention_guidance(
|
||||
enable=True, scale=0.75
|
||||
) # Enable self-attention guidance to enhance the quality of the generated images
|
||||
|
||||
# Enable self-attention guidance to enhance the quality of the generated images
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
# ... Inference process
|
||||
|
||||
|
@ -76,10 +76,10 @@ with no_grad(): # Disable gradient calculation for memory-efficient inference
|
|||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
manual_seed(seed=seed)
|
||||
manual_seed(seed)
|
||||
|
||||
# Using a higher latents inner dim to improve resolution of generated images
|
||||
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
|
||||
# SDXL typically generates 1024x1024, here we use a higher resolution.
|
||||
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
|
||||
|
||||
# Diffusion process
|
||||
for step in sdxl.steps:
|
||||
|
@ -131,8 +131,8 @@ predicted_image.save("vanilla_sdxl.png")
|
|||
|
||||
manual_seed(seed=seed)
|
||||
|
||||
# Using a higher latents inner dim to improve resolution of generated images
|
||||
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
|
||||
# SDXL typically generates 1024x1024, here we use a higher resolution.
|
||||
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
|
||||
|
||||
# Diffusion process
|
||||
for step in sdxl.steps:
|
||||
|
@ -213,8 +213,8 @@ manager.add_loras("scifi-lora", tensors=scifi_lora_weights)
|
|||
|
||||
manual_seed(seed=seed)
|
||||
|
||||
# Using a higher latents inner dim to improve resolution of generated images
|
||||
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
|
||||
# SDXL typically generates 1024x1024, here we use a higher resolution.
|
||||
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
|
||||
|
||||
# Diffusion process
|
||||
for step in sdxl.steps:
|
||||
|
@ -304,8 +304,8 @@ manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.saf
|
|||
|
||||
manual_seed(seed=seed)
|
||||
|
||||
# Using a higher latents inner dim to improve resolution of generated images
|
||||
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
|
||||
# SDXL typically generates 1024x1024, here we use a higher resolution.
|
||||
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
|
||||
|
||||
# Diffusion process
|
||||
for step in sdxl.steps:
|
||||
|
@ -440,7 +440,7 @@ with torch.no_grad():
|
|||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
manual_seed(seed=seed)
|
||||
x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype)
|
||||
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
|
||||
|
||||
# Diffusion process
|
||||
for step in sdxl.steps:
|
||||
|
@ -578,7 +578,7 @@ with torch.no_grad():
|
|||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
||||
manual_seed(seed=seed)
|
||||
x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype)
|
||||
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
|
||||
|
||||
# Diffusion process
|
||||
for step in sdxl.steps:
|
||||
|
|
|
@ -253,6 +253,20 @@ def download_loras():
|
|||
)
|
||||
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder, expected_hash="ee170e4d")
|
||||
|
||||
dest_folder = os.path.join(test_weights_dir, "loras")
|
||||
download_file(
|
||||
"https://civitai.com/api/download/models/140624",
|
||||
filename="Sci-fi_Environments_sdxl.safetensors",
|
||||
dest_folder=dest_folder,
|
||||
expected_hash="6a4afda8",
|
||||
)
|
||||
download_file(
|
||||
"https://civitai.com/api/download/models/135931",
|
||||
filename="pixel-art-xl-v1.1.safetensors",
|
||||
dest_folder=dest_folder,
|
||||
expected_hash="71aaa6ca",
|
||||
)
|
||||
|
||||
|
||||
def download_preprocessors():
|
||||
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
||||
|
|
321
tests/e2e/test_doc_examples.py
Normal file
321
tests/e2e/test_doc_examples.py
Normal file
|
@ -0,0 +1,321 @@
|
|||
import gc
|
||||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
||||
from refiners.foundationals.latent_diffusion import SDXLIPAdapter
|
||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ensure_gc():
|
||||
# Avoid GPU OOMs
|
||||
# See https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812
|
||||
gc.collect()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_path(test_e2e_path: Path) -> Path:
|
||||
return test_e2e_path / "test_doc_examples_ref"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
|
||||
path = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
|
||||
if not path.is_file():
|
||||
warn(message=f"could not find weights at {path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
|
||||
path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
||||
if not path.is_file():
|
||||
warn(message=f"could not find weights at {path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_unet_weights(test_weights_path: Path) -> Path:
|
||||
path = test_weights_path / "sdxl-unet.safetensors"
|
||||
if not path.is_file():
|
||||
warn(message=f"could not find weights at {path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
|
||||
path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
|
||||
if not path.is_file():
|
||||
warn(f"could not find weights at {path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def image_encoder_weights(test_weights_path: Path) -> Path:
|
||||
path = test_weights_path / "CLIPImageEncoderH.safetensors"
|
||||
if not path.is_file():
|
||||
warn(f"could not find weights at {path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scifi_lora_weights(test_weights_path: Path) -> Path:
|
||||
path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors"
|
||||
if not path.is_file():
|
||||
warn(message=f"could not find weights at {path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pixelart_lora_weights(test_weights_path: Path) -> Path:
|
||||
path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors"
|
||||
if not path.is_file():
|
||||
warn(message=f"could not find weights at {path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sdxl(
|
||||
sdxl_text_encoder_weights: Path,
|
||||
sdxl_lda_fp16_fix_weights: Path,
|
||||
sdxl_unet_weights: Path,
|
||||
test_device: torch.device,
|
||||
) -> StableDiffusion_XL:
|
||||
if test_device.type == "cpu":
|
||||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16)
|
||||
|
||||
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
|
||||
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
|
||||
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
|
||||
|
||||
return sdxl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_prompt_german_castle(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "german-castle.jpg").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_guide_adapting_sdxl_vanilla(
|
||||
test_device: torch.device,
|
||||
sdxl: StableDiffusion_XL,
|
||||
expected_image_guide_adapting_sdxl_vanilla: Image.Image,
|
||||
) -> None:
|
||||
if test_device.type == "cpu":
|
||||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
expected_image = expected_image_guide_adapting_sdxl_vanilla
|
||||
|
||||
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
|
||||
seed = 42
|
||||
sdxl.set_inference_steps(50, first_step=0)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt + ", best quality, high quality",
|
||||
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
manual_seed(seed)
|
||||
# The guide uses 2048x2048 but it is too slow for tests.
|
||||
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_guide_adapting_sdxl_single_lora(
|
||||
test_device: torch.device,
|
||||
sdxl: StableDiffusion_XL,
|
||||
scifi_lora_weights: Path,
|
||||
expected_image_guide_adapting_sdxl_single_lora: Image.Image,
|
||||
) -> None:
|
||||
if test_device.type == "cpu":
|
||||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
expected_image = expected_image_guide_adapting_sdxl_single_lora
|
||||
|
||||
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
|
||||
seed = 42
|
||||
sdxl.set_inference_steps(50, first_step=0)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
manager = SDLoraManager(sdxl)
|
||||
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
|
||||
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt + ", best quality, high quality",
|
||||
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
manual_seed(seed)
|
||||
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_guide_adapting_sdxl_multiple_loras(
|
||||
test_device: torch.device,
|
||||
sdxl: StableDiffusion_XL,
|
||||
scifi_lora_weights: Path,
|
||||
pixelart_lora_weights: Path,
|
||||
expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
|
||||
) -> None:
|
||||
if test_device.type == "cpu":
|
||||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
expected_image = expected_image_guide_adapting_sdxl_multiple_loras
|
||||
|
||||
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
|
||||
seed = 42
|
||||
sdxl.set_inference_steps(50, first_step=0)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
manager = SDLoraManager(sdxl)
|
||||
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
|
||||
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4)
|
||||
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt + ", best quality, high quality",
|
||||
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
manual_seed(seed)
|
||||
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_guide_adapting_sdxl_loras_ip_adapter(
|
||||
test_device: torch.device,
|
||||
sdxl: StableDiffusion_XL,
|
||||
sdxl_ip_adapter_plus_weights: Path,
|
||||
image_encoder_weights: Path,
|
||||
scifi_lora_weights: Path,
|
||||
pixelart_lora_weights: Path,
|
||||
image_prompt_german_castle: Image.Image,
|
||||
expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
|
||||
) -> None:
|
||||
if test_device.type == "cpu":
|
||||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
expected_image = expected_image_guide_adapting_sdxl_loras_ip_adapter
|
||||
|
||||
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
|
||||
seed = 42
|
||||
sdxl.set_inference_steps(50, first_step=0)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
manager = SDLoraManager(sdxl)
|
||||
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5)
|
||||
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55)
|
||||
|
||||
ip_adapter = SDXLIPAdapter(
|
||||
target=sdxl.unet,
|
||||
weights=load_from_safetensors(sdxl_ip_adapter_plus_weights),
|
||||
scale=1.0,
|
||||
fine_grained=True,
|
||||
)
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt + ", best quality, high quality",
|
||||
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
image_prompt_preprocessed = ip_adapter.preprocess_image(image_prompt_german_castle)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(image_prompt_preprocessed)
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
manual_seed(seed)
|
||||
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
# We do not (yet) test the last example using T2i-Adapter with Zoe Depth.
|
5
tests/e2e/test_doc_examples_ref/README.md
Normal file
5
tests/e2e/test_doc_examples_ref/README.md
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Note about this data
|
||||
|
||||
Everything in this directory comes from Refiners' documentation.
|
||||
|
||||
Some outputs are different because we perform inference in 1024x1024 and not 2048x2048.
|
Binary file not shown.
After Width: | Height: | Size: 1.8 MiB |
Binary file not shown.
After Width: | Height: | Size: 1.1 MiB |
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
Binary file not shown.
After Width: | Height: | Size: 1.7 MiB |
BIN
tests/e2e/test_doc_examples_ref/german-castle.jpg
Normal file
BIN
tests/e2e/test_doc_examples_ref/german-castle.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.7 MiB |
Loading…
Reference in a new issue