mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add end-to-end test for multi-ip adapter
This commit is contained in:
parent
ca5c5a7ca5
commit
5634e68fde
|
@ -108,6 +108,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
|
||||
|
@ -1334,6 +1339,46 @@ def test_diffusion_ip_adapter(
|
|||
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_ip_adapter_multi(
|
||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||
ip_adapter_weights: Path,
|
||||
image_encoder_weights: Path,
|
||||
woman_image: Image.Image,
|
||||
statue_image: Image.Image,
|
||||
expected_image_ip_adapter_multi: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding([woman_image, statue_image], weights=[1.0, 1.4])
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
sd15.set_inference_steps(50)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_sdxl_ip_adapter(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
|
|
|
@ -50,6 +50,7 @@ Special cases:
|
|||
- `expected_dropy_slime_9752.png`
|
||||
- `expected_sdxl_dpo_lora.png`
|
||||
- `expected_sdxl_multi_loras.png`
|
||||
- `expected_image_ip_adapter_multi.png`
|
||||
|
||||
## Other images
|
||||
|
||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_image_ip_adapter_multi.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_image_ip_adapter_multi.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 351 KiB |
Loading…
Reference in a new issue