write ControlLora e2e tests

This commit is contained in:
Laurent 2024-02-14 15:27:13 +00:00 committed by Laureηt
parent 5fee723cd1
commit 7fe392298a
8 changed files with 159 additions and 0 deletions

View file

@ -1,4 +1,5 @@
import gc import gc
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Iterator from typing import Iterator
from warnings import warn from warnings import warn
@ -27,6 +28,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import Refer
from refiners.foundationals.latent_diffusion.restart import Restart from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLoraAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
@ -185,6 +187,84 @@ def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str,
return cn_name, condition_image, expected_image, weights_path return cn_name, condition_image, expected_image, weights_path
@dataclass
class ControlLoraConfig:
    """Declarative description of one ControlLora to apply in a test scenario.

    Paths are plain filenames; they are resolved against the test asset
    directories by the ``controllora_sdxl_config`` fixture.
    """

    # influence of this ControlLora on the diffusion (0.0 appears to disable it — see the "disabled" scenario)
    scale: float
    # filename of the condition image, relative to the reference images directory
    condition_path: str
    # filename of the safetensors weights, relative to the test weights directory
    weights_path: str
@dataclass
class ControlLoraResolvedConfig:
    """A ControlLoraConfig with its assets materialized: condition image loaded, weights path made absolute."""

    # influence of this ControlLora on the diffusion
    scale: float
    # condition image, already opened and converted to RGB
    condition_image: Image.Image
    # absolute path to the safetensors weights file
    weights_path: Path
# Asset filenames shared by the scenarios below.
_PYRACANNY_CONDITION = "cutecat_guide_PyraCanny.png"
_PYRACANNY_WEIGHTS = "refiners_control-lora-canny-rank128.safetensors"
_CPDS_CONDITION = "cutecat_guide_CPDS.png"
_CPDS_WEIGHTS = "refiners_fooocus_xl_cpds_128.safetensors"


def _pyracanny(scale: float) -> ControlLoraConfig:
    """Build the PyraCanny ControlLora config at the given scale."""
    return ControlLoraConfig(
        scale=scale,
        condition_path=_PYRACANNY_CONDITION,
        weights_path=_PYRACANNY_WEIGHTS,
    )


def _cpds(scale: float) -> ControlLoraConfig:
    """Build the CPDS ControlLora config at the given scale."""
    return ControlLoraConfig(
        scale=scale,
        condition_path=_CPDS_CONDITION,
        weights_path=_CPDS_WEIGHTS,
    )


# Maps each expected reference image to the ControlLora(s) used to produce it.
CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
    "expected_controllora_PyraCanny.png": {
        "PyraCanny": _pyracanny(1.0),
    },
    "expected_controllora_CPDS.png": {
        "CPDS": _cpds(1.0),
    },
    "expected_controllora_PyraCanny+CPDS.png": {
        "PyraCanny": _pyracanny(0.55),
        "CPDS": _cpds(0.55),
    },
    "expected_controllora_disabled.png": {
        "PyraCanny": _pyracanny(0.0),
        "CPDS": _cpds(0.0),
    },
}
@pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
def controllora_sdxl_config(
    request: pytest.FixtureRequest,
    ref_path: Path,
    test_weights_path: Path,
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
    """Resolve one CONTROL_LORA_CONFIGS entry into loaded images and absolute weight paths.

    Yields one (expected image, resolved configs) pair per scenario declared
    in CONTROL_LORA_CONFIGS.
    """
    expected_name, raw_configs = request.param  # (reference image filename, {name: ControlLoraConfig})
    reference_image = Image.open(ref_path / expected_name).convert("RGB")

    resolved: dict[str, ControlLoraResolvedConfig] = {}
    for key, raw in raw_configs.items():
        resolved[key] = ControlLoraResolvedConfig(
            scale=raw.scale,
            condition_image=Image.open(ref_path / raw.condition_path).convert("RGB"),
            weights_path=test_weights_path / "control_lora" / raw.weights_path,
        )

    return reference_image, resolved
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth" name = "depth"
@ -1074,6 +1154,79 @@ def test_diffusion_controlnet_stack(
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_sdxl_controllora(
    controllora_sdxl_config: tuple[Image.Image, dict[str, ControlLoraResolvedConfig]],
    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
) -> None:
    """End-to-end SDXL generation guided by one or more ControlLora adapters."""
    expected_image, configs = controllora_sdxl_config

    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
    sdxl.dtype = torch.float16  # FIXME: should not be necessary

    # build one adapter per scenario entry, each with its own weights and condition image
    adapters: dict[str, ControlLoraAdapter] = {}
    for config_name, config in configs.items():
        weights = load_from_safetensors(path=config.weights_path, device=sdxl.device)
        adapter = ControlLoraAdapter(
            name=config_name,
            scale=config.scale,
            target=sdxl.unet,
            weights=weights,
        )
        condition = image_to_tensor(image=config.condition_image, device=sdxl.device, dtype=sdxl.dtype)
        adapter.set_condition(condition)
        adapters[config_name] = adapter

    # inject all the control lora adapters
    for adapter in adapters.values():
        adapter.inject()

    # compute the text embeddings
    prompt = "a cute cat, flying in the air, detailed high-quality professional image, blank background"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality, watermarks"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt,
        negative_text=negative_prompt,
    )

    # initialize the latents with a fixed seed for reproducibility
    manual_seed(2)
    x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=sdxl.default_time_ids,
        )

    # decode the latents and compare against the reference image
    predicted_image = sdxl.lda.decode_latents(x)
    ensure_similar_images(
        img_1=predicted_image,
        img_2=expected_image,
        min_psnr=35,
        min_ssim=0.99,
    )
@no_grad() @no_grad()
def test_diffusion_lora( def test_diffusion_lora(
sd15_std: StableDiffusion_1, sd15_std: StableDiffusion_1,

View file

@ -52,6 +52,10 @@ Special cases:
- `expected_sdxl_dpo_lora.png` - `expected_sdxl_dpo_lora.png`
- `expected_sdxl_multi_loras.png` - `expected_sdxl_multi_loras.png`
- `expected_image_ip_adapter_multi.png` - `expected_image_ip_adapter_multi.png`
- `expected_controllora_CPDS.png`
- `expected_controllora_PyraCanny.png`
- `expected_controllora_PyraCanny+CPDS.png`
- `expected_controllora_disabled.png`
## Other images ## Other images
@ -81,6 +85,8 @@ Special cases:
- `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png). - `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png).
- `cutecat_guide_PyraCanny.png` and `cutecat_guide_CPDS.png` were [generated inside Fooocus](https://github.com/lllyasviel/Fooocus/blob/e8d88d3e250e541c6daf99d6ef734e8dc3cfdc7f/extras/preprocessors.py).
## VAE without randomness ## VAE without randomness
```diff ```diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 633 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 448 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB