add e2e test for T2I-Adapter XL canny

This commit is contained in:
Cédric Deltheil 2023-09-24 22:05:56 +02:00 committed by Cédric Deltheil
parent 4301e81eb3
commit f37f25a2e4
4 changed files with 69 additions and 2 deletions

View file

@ -16,6 +16,7 @@ from refiners.foundationals.latent_diffusion import (
SD1IPAdapter, SD1IPAdapter,
SD1T2IAdapter, SD1T2IAdapter,
SDXLIPAdapter, SDXLIPAdapter,
SDXLT2IAdapter,
) )
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
@ -129,6 +130,7 @@ def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str,
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors" weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
return cn_name, condition_image, expected_image, weights_path return cn_name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth" name = "depth"
@ -138,6 +140,15 @@ def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str
return name, condition_image, expected_image, weights_path return name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "canny"
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
return name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]: def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB") expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
@ -1283,3 +1294,52 @@ def test_t2i_adapter_depth(
predicted_image = sd15.lda.decode_latents(x) predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image)
@torch.no_grad()
def test_t2i_adapter_xl_canny(
sdxl_ddim: StableDiffusion_XL,
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sdxl = sdxl_ddim
n_steps = 30
name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "Mystical fairy in real, magic, 4k picture, high quality"
negative_prompt = (
"extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
t2i_adapter.set_scale(0.8)
condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
manual_seed(2)
x = torch.randn(1, 4, condition_image.height // 8, condition_image.width // 8, device=test_device)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=7.5,
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)

View file

@ -35,7 +35,12 @@ output.images[0].save("std_random_init_expected.png")
Special cases: Special cases:
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui). - `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` and `expected_ip_adapter_controlnet.png` have been generated with refiners itself (and inspected so that they look reasonable). - The following references have been generated with refiners itself (and inspected so that they look reasonable):
- `expected_inpainting_refonly.png`,
- `expected_image_ip_adapter_woman.png`,
- `expected_image_sdxl_ip_adapter_woman.png`
- `expected_ip_adapter_controlnet.png`
- `expected_t2i_adapter_xl_canny.png`
## Other images ## Other images
@ -45,13 +50,15 @@ Special cases:
- `kitchen_mask.png` is made manually. - `kitchen_mask.png` is made manually.
- Controlnet guides have been manually generated using open source software and models, namely: - Controlnet guides have been manually generated (x) using open source software and models, namely:
- Canny: opencv-python - Canny: opencv-python
- Depth: https://github.com/isl-org/ZoeDepth - Depth: https://github.com/isl-org/ZoeDepth
- Lineart: https://github.com/lllyasviel/ControlNet-v1-1-nightly/tree/main/annotator/lineart - Lineart: https://github.com/lllyasviel/ControlNet-v1-1-nightly/tree/main/annotator/lineart
- Normals: https://github.com/baegwangbin/surface_normal_uncertainty/tree/fe2b9f1 - Normals: https://github.com/baegwangbin/surface_normal_uncertainty/tree/fe2b9f1
- SAM: https://huggingface.co/spaces/mfidabel/controlnet-segment-anything - SAM: https://huggingface.co/spaces/mfidabel/controlnet-segment-anything
(x): excepted `fairy_guide_canny.png` which comes from [TencentARC/t2i-adapter-canny-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0)
- `cyberpunk_guide.png` [comes from Lexica](https://lexica.art/prompt/5ba40855-0d0c-4322-8722-51115985f573). - `cyberpunk_guide.png` [comes from Lexica](https://lexica.art/prompt/5ba40855-0d0c-4322-8722-51115985f573).
- `inpainting-mask.png`, `inpainting-scene.png` and `inpainting-target.png` have been generated as follows: - `inpainting-mask.png`, `inpainting-scene.png` and `inpainting-target.png` have been generated as follows:

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 149 KiB