mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
write ControlLora
e2e tests
This commit is contained in:
parent
5fee723cd1
commit
7fe392298a
|
@ -1,4 +1,5 @@
|
||||||
import gc
|
import gc
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
@ -27,6 +28,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import Refer
|
||||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule
|
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLoraAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
@ -185,6 +187,84 @@ def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str,
|
||||||
return cn_name, condition_image, expected_image, weights_path
|
return cn_name, condition_image, expected_image, weights_path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ControlLoraConfig:
    """Unresolved description of one Control LoRA used by an e2e test case.

    Both paths are plain file names. The `controllora_sdxl_config` fixture
    resolves `condition_path` against the reference-image directory and
    `weights_path` against the `control_lora` sub-directory of the test
    weights directory.

    The dataclass is frozen because its instances live in the module-level
    `CONTROL_LORA_CONFIGS` mapping, which is shared by every parametrized
    test case and must not be mutated by any of them.
    """

    # Scale passed to the adapter for this Control LoRA.
    scale: float
    # File name of the condition (guide) image.
    condition_path: str
    # File name of the safetensors weights.
    weights_path: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ControlLoraResolvedConfig:
    """Resolved counterpart of `ControlLoraConfig`, ready to be consumed by a test.

    The condition image is already opened and the weights path is a full
    `Path`. Instances are built by the `controllora_sdxl_config` fixture and
    only read afterwards, hence the frozen dataclass.
    """

    # Scale passed to the adapter for this Control LoRA.
    scale: float
    # Condition (guide) image, loaded by the fixture.
    condition_image: Image.Image
    # Full path to the safetensors weights.
    weights_path: Path
|
||||||
|
|
||||||
|
|
||||||
|
# Parametrization table for the SDXL Control LoRA e2e test.
#
# Outer key: file name of the expected (reference) image for the test case.
# Inner key: name given to the Control LoRA adapter for that case.
# Inner value: scale, condition image file name and weights file name.
CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
    # single Control LoRA: PyraCanny
    "expected_controllora_PyraCanny.png": {
        "PyraCanny": ControlLoraConfig(
            scale=1.0,
            condition_path="cutecat_guide_PyraCanny.png",
            weights_path="refiners_control-lora-canny-rank128.safetensors",
        ),
    },
    # single Control LoRA: CPDS
    "expected_controllora_CPDS.png": {
        "CPDS": ControlLoraConfig(
            scale=1.0,
            condition_path="cutecat_guide_CPDS.png",
            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
        ),
    },
    # both Control LoRAs at once, each at a reduced scale
    "expected_controllora_PyraCanny+CPDS.png": {
        "PyraCanny": ControlLoraConfig(
            scale=0.55,
            condition_path="cutecat_guide_PyraCanny.png",
            weights_path="refiners_control-lora-canny-rank128.safetensors",
        ),
        "CPDS": ControlLoraConfig(
            scale=0.55,
            condition_path="cutecat_guide_CPDS.png",
            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
        ),
    },
    # both Control LoRAs injected but with a zero scale
    "expected_controllora_disabled.png": {
        "PyraCanny": ControlLoraConfig(
            scale=0.0,
            condition_path="cutecat_guide_PyraCanny.png",
            weights_path="refiners_control-lora-canny-rank128.safetensors",
        ),
        "CPDS": ControlLoraConfig(
            scale=0.0,
            condition_path="cutecat_guide_CPDS.png",
            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
        ),
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONTROL_LORA_CONFIGS.items(), ids=list(CONTROL_LORA_CONFIGS))
def controllora_sdxl_config(
    request: pytest.FixtureRequest,
    ref_path: Path,
    test_weights_path: Path,
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
    """Resolve one entry of `CONTROL_LORA_CONFIGS` into loaded test inputs.

    The fixture is parametrized over the items of `CONTROL_LORA_CONFIGS`; the
    expected image file name is used as the test id so that each case is
    identifiable in the pytest report.

    Args:
        request: pytest request whose `param` is an
            `(expected image file name, configs)` pair.
        ref_path: directory containing the expected and condition images.
        test_weights_path: directory containing the test weights; Control LoRA
            weights are looked up in its `control_lora` sub-directory.

    Returns:
        The expected image (converted to RGB) and, per adapter name, the
        resolved config holding the scale, the condition image (converted to
        RGB) and the full weights path.
    """
    name: str
    configs: dict[str, ControlLoraConfig]
    name, configs = request.param

    expected_image = Image.open(ref_path / name).convert("RGB")

    loaded_configs = {
        config_name: ControlLoraResolvedConfig(
            scale=config.scale,
            condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
            weights_path=test_weights_path / "control_lora" / config.weights_path,
        )
        for config_name, config in configs.items()
    }

    return expected_image, loaded_configs
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
name = "depth"
|
name = "depth"
|
||||||
|
@ -1074,6 +1154,79 @@ def test_diffusion_controlnet_stack(
|
||||||
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_diffusion_sdxl_controllora(
    controllora_sdxl_config: tuple[Image.Image, dict[str, ControlLoraResolvedConfig]],
    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
) -> None:
    """Check SDXL generation with one or more Control LoRA adapters against a reference image.

    Args:
        controllora_sdxl_config: the expected image and, per adapter name, the
            resolved Control LoRA config (scale, condition image, weights path).
        sdxl_ddim_lda_fp16_fix: the SDXL model under test; it is cast to
            float16 before use.
    """
    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
    sdxl.dtype = torch.float16  # FIXME: should not be necessary

    expected_image, configs = controllora_sdxl_config

    # build every adapter and set its condition before any of them is injected
    adapters: dict[str, ControlLoraAdapter] = {}
    for config_name, config in configs.items():
        adapter = ControlLoraAdapter(
            name=config_name,
            scale=config.scale,
            target=sdxl.unet,
            weights=load_from_safetensors(
                path=config.weights_path,
                device=sdxl.device,
            ),
        )
        adapter.set_condition(
            image_to_tensor(
                image=config.condition_image,
                device=sdxl.device,
                dtype=sdxl.dtype,
            )
        )
        adapters[config_name] = adapter

    # inject all the control lora adapters
    for adapter in adapters.values():
        adapter.inject()

    # compute the text embeddings
    prompt = "a cute cat, flying in the air, detailed high-quality professional image, blank background"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality, watermarks"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt,
        negative_text=negative_prompt,
    )

    # initialize the latents (fixed seed so the output is comparable to the reference)
    manual_seed(2)
    x = torch.randn(
        (1, 4, 128, 128),
        device=sdxl.device,
        dtype=sdxl.dtype,
    )

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=sdxl.default_time_ids,
        )

    # decode latent to image
    predicted_image = sdxl.lda.decode_latents(x)

    # ensure the predicted image is similar to the expected image
    ensure_similar_images(
        img_1=predicted_image,
        img_2=expected_image,
        min_psnr=35,
        min_ssim=0.99,
    )
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_lora(
|
def test_diffusion_lora(
|
||||||
sd15_std: StableDiffusion_1,
|
sd15_std: StableDiffusion_1,
|
||||||
|
|
|
@ -52,6 +52,10 @@ Special cases:
|
||||||
- `expected_sdxl_dpo_lora.png`
|
- `expected_sdxl_dpo_lora.png`
|
||||||
- `expected_sdxl_multi_loras.png`
|
- `expected_sdxl_multi_loras.png`
|
||||||
- `expected_image_ip_adapter_multi.png`
|
- `expected_image_ip_adapter_multi.png`
|
||||||
|
- `expected_controllora_CPDS.png`
|
||||||
|
- `expected_controllora_PyraCanny.png`
|
||||||
|
- `expected_controllora_PyraCanny+CPDS.png`
|
||||||
|
- `expected_controllora_disabled.png`
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
@ -81,6 +85,8 @@ Special cases:
|
||||||
|
|
||||||
- `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png).
|
- `statue.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/assets/images/statue.png).
|
||||||
|
|
||||||
|
- `cutecat_guide_PyraCanny.png` and `cutecat_guide_CPDS.png` were [generated inside Fooocus](https://github.com/lllyasviel/Fooocus/blob/e8d88d3e250e541c6daf99d6ef734e8dc3cfdc7f/extras/preprocessors.py).
|
||||||
|
|
||||||
## VAE without randomness
|
## VAE without randomness
|
||||||
|
|
||||||
```diff
|
```diff
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_CPDS.png
Normal file
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_CPDS.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 633 KiB |
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_PyraCanny.png
Normal file
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_PyraCanny.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 448 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_controllora_CPDS.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_controllora_CPDS.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.3 MiB |
Binary file not shown.
After Width: | Height: | Size: 1.4 MiB |
BIN
tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_controllora_PyraCanny.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.4 MiB |
BIN
tests/e2e/test_diffusion_ref/expected_controllora_disabled.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_controllora_disabled.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
Loading…
Reference in a new issue