refiners/tests/e2e/test_doc_examples.py

326 lines
11 KiB
Python
Raw Normal View History

2024-03-06 09:56:46 +00:00
import gc
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
2024-06-24 14:22:11 +00:00
from tests.utils import ensure_similar_images
2024-03-06 09:56:46 +00:00
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.latent_diffusion import SDXLIPAdapter
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
def _img_open(path: Path) -> Image.Image:
    """Open the image file located at ``path`` with Pillow and return it."""
    opened = Image.open(path)  # type: ignore
    return opened
2024-03-06 09:56:46 +00:00
@pytest.fixture(autouse=True)
def ensure_gc() -> None:
    """Run a garbage-collection pass during the setup of every test."""
    # Avoid GPU OOMs
    # See https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812
    gc.collect()
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
    """Return the ``test_doc_examples_ref`` directory under ``test_e2e_path``.

    The image fixtures of this module load their files from this directory.
    """
    return test_e2e_path.joinpath("test_doc_examples_ref")
@pytest.fixture(scope="module")
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
    """Return the path to ``DoubleCLIPTextEncoder.safetensors``.

    When the file is not present under ``test_weights_path``, a warning is
    emitted and the requesting test is skipped.
    """
    weights_file = test_weights_path.joinpath("DoubleCLIPTextEncoder.safetensors")
    if weights_file.is_file():
        return weights_file
    warn(message=f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module")
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
    """Return the path to ``sdxl-lda-fp16-fix.safetensors``.

    When the file is not present under ``test_weights_path``, a warning is
    emitted and the requesting test is skipped.
    """
    weights_file = test_weights_path.joinpath("sdxl-lda-fp16-fix.safetensors")
    if weights_file.is_file():
        return weights_file
    warn(message=f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module")
def sdxl_unet_weights(test_weights_path: Path) -> Path:
    """Return the path to ``sdxl-unet.safetensors``.

    When the file is not present under ``test_weights_path``, a warning is
    emitted and the requesting test is skipped.
    """
    weights_file = test_weights_path.joinpath("sdxl-unet.safetensors")
    if weights_file.is_file():
        return weights_file
    warn(message=f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
    """Return the path to ``ip-adapter-plus_sdxl_vit-h.safetensors``.

    When the file is not present under ``test_weights_path``, a warning is
    emitted and the requesting test is skipped.
    """
    weights_file = test_weights_path.joinpath("ip-adapter-plus_sdxl_vit-h.safetensors")
    if weights_file.is_file():
        return weights_file
    warn(f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
    """Return the path to ``CLIPImageEncoderH.safetensors``.

    When the file is not present under ``test_weights_path``, a warning is
    emitted and the requesting test is skipped.
    """
    weights_file = test_weights_path.joinpath("CLIPImageEncoderH.safetensors")
    if weights_file.is_file():
        return weights_file
    warn(f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture
def scifi_lora_weights(test_weights_path: Path) -> Path:
    """Return the path to ``loras/Sci-fi_Environments_sdxl.safetensors``.

    When the file is not present under ``test_weights_path``, a warning is
    emitted and the requesting test is skipped.
    """
    weights_file = test_weights_path.joinpath("loras", "Sci-fi_Environments_sdxl.safetensors")
    if weights_file.is_file():
        return weights_file
    warn(message=f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture
def pixelart_lora_weights(test_weights_path: Path) -> Path:
    """Return the path to ``loras/pixel-art-xl-v1.1.safetensors``.

    When the file is not present under ``test_weights_path``, a warning is
    emitted and the requesting test is skipped.
    """
    weights_file = test_weights_path.joinpath("loras", "pixel-art-xl-v1.1.safetensors")
    if weights_file.is_file():
        return weights_file
    warn(message=f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture
def sdxl(
    sdxl_text_encoder_weights: Path,
    sdxl_lda_fp16_fix_weights: Path,
    sdxl_unet_weights: Path,
    test_device: torch.device,
) -> StableDiffusion_XL:
    """Build a float16 ``StableDiffusion_XL`` on ``test_device`` with its weights loaded.

    The text encoder, the LDA and the UNet are each loaded from their own
    safetensors file. When ``test_device`` is a CPU device, a warning is
    emitted and the requesting test is skipped instead.
    """
    running_on_cpu = test_device.type == "cpu"
    if running_on_cpu:
        warn(message="not running on CPU, skipping")
        pytest.skip()

    model = StableDiffusion_XL(device=test_device, dtype=torch.float16)

    # Pair each sub-model with its checkpoint; loading order matches the listing order.
    checkpoints = (
        (model.clip_text_encoder, sdxl_text_encoder_weights),
        (model.lda, sdxl_lda_fp16_fix_weights),
        (model.unet, sdxl_unet_weights),
    )
    for component, weights_file in checkpoints:
        component.load_from_safetensors(tensors_path=weights_file)

    return model
@pytest.fixture
def image_prompt_german_castle(ref_path: Path) -> Image.Image:
    """Load ``german-castle.jpg`` from ``ref_path`` and convert it to RGB."""
    castle = _img_open(ref_path.joinpath("german-castle.jpg"))
    return castle.convert("RGB")
2024-03-06 09:56:46 +00:00
@pytest.fixture
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
    """Load the reference image for the vanilla SDXL test, converted to RGB."""
    reference = _img_open(ref_path.joinpath("expected_image_guide_adapting_sdxl_vanilla.png"))
    return reference.convert("RGB")
2024-03-06 09:56:46 +00:00
@pytest.fixture
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
    """Load the reference image for the single-LoRA test, converted to RGB."""
    reference = _img_open(ref_path.joinpath("expected_image_guide_adapting_sdxl_single_lora.png"))
    return reference.convert("RGB")
2024-03-06 09:56:46 +00:00
@pytest.fixture
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
    """Load the reference image for the multiple-LoRA test, converted to RGB."""
    reference = _img_open(ref_path.joinpath("expected_image_guide_adapting_sdxl_multiple_loras.png"))
    return reference.convert("RGB")
2024-03-06 09:56:46 +00:00
@pytest.fixture
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
    """Load the reference image for the LoRAs + IP-Adapter test, converted to RGB."""
    reference = _img_open(ref_path.joinpath("expected_image_guide_adapting_sdxl_loras_ip_adapter.png"))
    return reference.convert("RGB")
2024-03-06 09:56:46 +00:00
@no_grad()
def test_guide_adapting_sdxl_vanilla(
    test_device: torch.device,
    sdxl: StableDiffusion_XL,
    expected_image_guide_adapting_sdxl_vanilla: Image.Image,
) -> None:
    """Run SDXL without any adapter and compare the output with the reference image."""
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    base_prompt = "a futuristic castle surrounded by a forest, mountains in the background"

    sdxl.set_inference_steps(50, first_step=0)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    text_embedding, pooled_embedding = sdxl.compute_clip_text_embedding(
        text=base_prompt + ", best quality, high quality",
        negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    # Conditioning inputs passed unchanged to the model at every step.
    conditioning = {
        "clip_text_embedding": text_embedding,
        "pooled_text_embedding": pooled_embedding,
        "time_ids": sdxl.default_time_ids,
    }

    # Seed right before creating the initial latents.
    manual_seed(42)

    # The guide uses 2048x2048 but it is too slow for tests.
    latents = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
    for step in sdxl.steps:
        latents = sdxl(latents, step=step, **conditioning)

    generated = sdxl.lda.decode_latents(latents)
    ensure_similar_images(generated, expected_image_guide_adapting_sdxl_vanilla)
@no_grad()
def test_guide_adapting_sdxl_single_lora(
    test_device: torch.device,
    sdxl: StableDiffusion_XL,
    scifi_lora_weights: Path,
    expected_image_guide_adapting_sdxl_single_lora: Image.Image,
) -> None:
    """Run SDXL with the sci-fi LoRA added and compare the output with the reference image."""
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    base_prompt = "a futuristic castle surrounded by a forest, mountains in the background"

    sdxl.set_inference_steps(50, first_step=0)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    # The LoRA is added before the text embeddings are computed, as in the original flow.
    lora_manager = SDLoraManager(sdxl)
    lora_manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))

    text_embedding, pooled_embedding = sdxl.compute_clip_text_embedding(
        text=base_prompt + ", best quality, high quality",
        negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    # Conditioning inputs passed unchanged to the model at every step.
    conditioning = {
        "clip_text_embedding": text_embedding,
        "pooled_text_embedding": pooled_embedding,
        "time_ids": sdxl.default_time_ids,
    }

    # Seed right before creating the initial latents.
    manual_seed(42)

    latents = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
    for step in sdxl.steps:
        latents = sdxl(latents, step=step, **conditioning)

    generated = sdxl.lda.decode_latents(latents)
    ensure_similar_images(generated, expected_image_guide_adapting_sdxl_single_lora)
@no_grad()
def test_guide_adapting_sdxl_multiple_loras(
    test_device: torch.device,
    sdxl: StableDiffusion_XL,
    scifi_lora_weights: Path,
    pixelart_lora_weights: Path,
    expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
) -> None:
    """Run SDXL with two LoRAs added and compare the output with the reference image.

    The sci-fi LoRA is added with its default scale and the pixel-art LoRA
    with ``scale=1.4``.
    """
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    base_prompt = "a futuristic castle surrounded by a forest, mountains in the background"

    sdxl.set_inference_steps(50, first_step=0)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    # Both LoRAs are added before the text embeddings are computed, as in the original flow.
    lora_manager = SDLoraManager(sdxl)
    lora_manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
    lora_manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4)

    text_embedding, pooled_embedding = sdxl.compute_clip_text_embedding(
        text=base_prompt + ", best quality, high quality",
        negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    # Conditioning inputs passed unchanged to the model at every step.
    conditioning = {
        "clip_text_embedding": text_embedding,
        "pooled_text_embedding": pooled_embedding,
        "time_ids": sdxl.default_time_ids,
    }

    # Seed right before creating the initial latents.
    manual_seed(42)

    latents = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
    for step in sdxl.steps:
        latents = sdxl(latents, step=step, **conditioning)

    generated = sdxl.lda.decode_latents(latents)
    ensure_similar_images(generated, expected_image_guide_adapting_sdxl_multiple_loras)
@no_grad()
def test_guide_adapting_sdxl_loras_ip_adapter(
    test_device: torch.device,
    sdxl: StableDiffusion_XL,
    sdxl_ip_adapter_plus_weights: Path,
    image_encoder_weights: Path,
    scifi_lora_weights: Path,
    pixelart_lora_weights: Path,
    image_prompt_german_castle: Image.Image,
    expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
) -> None:
    """Run SDXL with two LoRAs and an IP-Adapter and compare the output with the reference image.

    The LoRAs are added with scales 1.5 (sci-fi) and 1.55 (pixel-art). The
    IP-Adapter is injected into the UNet and fed the German castle image as
    image prompt.
    """
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    base_prompt = "a futuristic castle surrounded by a forest, mountains in the background"

    sdxl.set_inference_steps(50, first_step=0)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    # LoRAs first, then the IP-Adapter, then the embeddings: same order as the original flow.
    lora_manager = SDLoraManager(sdxl)
    lora_manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5)
    lora_manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55)

    adapter = SDXLIPAdapter(
        target=sdxl.unet,
        weights=load_from_safetensors(sdxl_ip_adapter_plus_weights),
        scale=1.0,
        fine_grained=True,
    )
    adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
    adapter.inject()

    text_embedding, pooled_embedding = sdxl.compute_clip_text_embedding(
        text=base_prompt + ", best quality, high quality",
        negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    # Conditioning inputs passed unchanged to the model at every step.
    conditioning = {
        "clip_text_embedding": text_embedding,
        "pooled_text_embedding": pooled_embedding,
        "time_ids": sdxl.default_time_ids,
    }

    # Embed the image prompt and register the embedding on the adapter.
    preprocessed_prompt = adapter.preprocess_image(image_prompt_german_castle)
    image_embedding = adapter.compute_clip_image_embedding(preprocessed_prompt)
    adapter.set_clip_image_embedding(image_embedding)

    # Seed right before creating the initial latents.
    manual_seed(42)

    latents = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
    for step in sdxl.steps:
        latents = sdxl(latents, step=step, **conditioning)

    generated = sdxl.lda.decode_latents(latents)
    ensure_similar_images(generated, expected_image_guide_adapting_sdxl_loras_ip_adapter)
# We do not (yet) test the last example, which uses T2I-Adapter with Zoe Depth.