mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add a test for IP-Adapter + ControlNet
This commit is contained in:
parent
cf9efb57c8
commit
c421cfd56c
|
@ -80,6 +80,11 @@ def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
|
||||
|
@ -1076,6 +1081,71 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_ip_adapter_controlnet(
|
||||
sd15_ddim: StableDiffusion_1,
|
||||
ip_adapter_weights: Path,
|
||||
image_encoder_weights: Path,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
|
||||
expected_image_ip_adapter_controlnet: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim.to(dtype=torch.float16)
|
||||
n_steps = 50
|
||||
input_image, _ = lora_data_pokemon # use the Pokemon LoRA output as input
|
||||
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
|
||||
|
||||
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
|
||||
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
|
||||
|
||||
clip_text_embedding = torch.cat(
|
||||
(
|
||||
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
|
||||
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
|
||||
)
|
||||
)
|
||||
|
||||
depth_controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet,
|
||||
name="depth",
|
||||
scale=1.0,
|
||||
weights=load_from_safetensors(depth_cn_weights_path),
|
||||
).inject()
|
||||
depth_cn_condition = image_to_tensor(
|
||||
depth_condition_image.convert("RGB"),
|
||||
device=test_device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
||||
for step in sd15.steps:
|
||||
depth_controlnet.set_controlnet_condition(depth_cn_condition)
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_sdxl_random_init(
|
||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
||||
|
|
|
@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
|
|||
Special cases:
|
||||
|
||||
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
|
||||
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that they look reasonable).
|
||||
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png`, `expected_image_sdxl_ip_adapter_woman.png` and `expected_ip_adapter_controlnet.png` have been generated with refiners itself (and inspected so that they look reasonable).
|
||||
|
||||
## Other images
|
||||
|
||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 309 KiB |
Loading…
Reference in a new issue