add e2e test for T2I-Adapter depth

Expected output generated with diffusers' StableDiffusionAdapterPipeline
This commit is contained in:
Cédric Deltheil 2023-09-24 21:32:06 +02:00 committed by Cédric Deltheil
parent 2106c237d9
commit 4301e81eb3
2 changed files with 50 additions and 0 deletions

View file

@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion import (
SD1UNet,
SD1ControlnetAdapter,
SD1IPAdapter,
SD1T2IAdapter,
SDXLIPAdapter,
)
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
@ -128,6 +129,14 @@ def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str,
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
return cn_name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth"
condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
return name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
@ -1233,3 +1242,44 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
)
result = sd.lda.decode_latents(x=x)
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_t2i_adapter_depth(
sd15_std: StableDiffusion_1,
t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
name, condition_image, expected_image, weights_path = t2i_adapter_data_depth
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)

Binary file not shown.

After

Width:  |  Height:  |  Size: 419 KiB