# refiners/tests/e2e/test_diffusion.py
import gc
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
from warnings import warn
import pytest
import torch
from PIL import Image
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion import (
ControlLoraAdapter,
SD1ControlnetAdapter,
SD1IPAdapter,
SD1T2IAdapter,
SD1UNet,
SDFreeUAdapter,
SDXLIPAdapter,
SDXLT2IAdapter,
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
from tests.utils import ensure_similar_images
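# Thin wrapper around Image.open so call sites get a plain Image.Image; a typing
# convenience, since Image.open's return type upsets strict type checkers (hence
# the ignore below).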
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True)
def ensure_gc():
# Avoid GPU OOMs
# See https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812
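    # Collecting before each test frees CUDA memory still held by garbage from
    # the previous test.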
gc.collect()
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_diffusion_ref"
@pytest.fixture(scope="module")
def cutecat_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cutecat_init.png").convert("RGB")
@pytest.fixture(scope="module")
def kitchen_dog(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "kitchen_dog.png").convert("RGB")
@pytest.fixture(scope="module")
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB")
@pytest.fixture(scope="module")
def woman_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "woman.png").convert("RGB")
@pytest.fixture(scope="module")
def statue_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "statue.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
@pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
@pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
@pytest.fixture
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB")
@pytest.fixture
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
@pytest.fixture
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
@pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
@pytest.fixture
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
@pytest.fixture
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
@pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
@pytest.fixture
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
@pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_fn = {
"depth": "lllyasviel_control_v11f1p_sd15_depth",
"canny": "lllyasviel_control_v11p_sd15_canny",
"lineart": "lllyasviel_control_v11p_sd15_lineart",
"normals": "lllyasviel_control_v11p_sd15_normalbae",
"sam": "mfidabel_controlnet-segment-anything",
}
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
yield (cn_name, condition_image, expected_image, weights_path)
@pytest.fixture(scope="module")
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "canny"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
return cn_name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
return cn_name, condition_image, expected_image, weights_path
@dataclass
class ControlLoraConfig:
scale: float
condition_path: str
weights_path: str
@dataclass
class ControlLoraResolvedConfig:
scale: float
condition_image: Image.Image
weights_path: Path
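# Maps an expected-image filename to the Control-LoRA adapters (name -> config)
# used to produce it; the "disabled" entry checks that scale=0.0 is a no-op.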
CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
"expected_controllora_PyraCanny.png": {
"PyraCanny": ControlLoraConfig(
scale=1.0,
condition_path="cutecat_guide_PyraCanny.png",
weights_path="refiners_control-lora-canny-rank128.safetensors",
),
},
"expected_controllora_CPDS.png": {
"CPDS": ControlLoraConfig(
scale=1.0,
condition_path="cutecat_guide_CPDS.png",
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
),
},
"expected_controllora_PyraCanny+CPDS.png": {
"PyraCanny": ControlLoraConfig(
scale=0.55,
condition_path="cutecat_guide_PyraCanny.png",
weights_path="refiners_control-lora-canny-rank128.safetensors",
),
"CPDS": ControlLoraConfig(
scale=0.55,
condition_path="cutecat_guide_CPDS.png",
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
),
},
"expected_controllora_disabled.png": {
"PyraCanny": ControlLoraConfig(
scale=0.0,
condition_path="cutecat_guide_PyraCanny.png",
weights_path="refiners_control-lora-canny-rank128.safetensors",
),
"CPDS": ControlLoraConfig(
scale=0.0,
condition_path="cutecat_guide_CPDS.png",
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
),
},
}
@pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
def controllora_sdxl_config(
request: pytest.FixtureRequest,
ref_path: Path,
test_weights_path: Path,
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
name: str = request.param[0]
configs: dict[str, ControlLoraConfig] = request.param[1]
expected_image = _img_open(ref_path / name).convert("RGB")
loaded_configs = {
config_name: ControlLoraResolvedConfig(
scale=config.scale,
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
weights_path=test_weights_path / "control-loras" / config.weights_path,
)
for config_name, config in configs.items()
}
return expected_image, loaded_configs
@pytest.fixture(scope="module")
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
return name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "canny"
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
return name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_tensors(weights_path)
return expected_image, tensors
@pytest.fixture(scope="module")
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights_path)
return expected_image, tensors
@pytest.fixture(scope="module")
def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
weights_path = test_weights_path / "loras" / "sliders"
if not weights_path.is_dir():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
return {
"age": load_tensors(weights_path / "age.pt"), # type: ignore
"cartoon_style": load_tensors(weights_path / "cartoon_style.pt"), # type: ignore
"eyesize": load_tensors(weights_path / "eyesize.pt"), # type: ignore
}, {
"age": 0.3,
"cartoon_style": -0.2,
"eyesize": -0.2,
}
@pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-scene.png").convert("RGB")
@pytest.fixture
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-mask.png").convert("RGB")
@pytest.fixture
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-target.png").convert("RGB")
@pytest.fixture
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
@pytest.fixture
def expected_image_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_refonly.png").convert("RGB")
@pytest.fixture
def condition_image_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB")
@pytest.fixture
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
@pytest.fixture
def expected_multi_diffusion(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
@pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
@pytest.fixture
def expected_freeu(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB")
@pytest.fixture
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
@pytest.fixture
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
assets = Path(__file__).parent.parent.parent / "assets"
dropy = assets / "dropy_logo.png"
image_prompt = assets / "dragon_quest_slime.jpg"
condition_image = assets / "dropy_canny.png"
return (
_img_open(dropy).convert(mode="RGB"),
_img_open(image_prompt).convert(mode="RGB"),
_img_open(condition_image).convert(mode="RGB"),
_img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
)
@pytest.fixture
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
@pytest.fixture(scope="module")
def text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not text_encoder_weights.is_file():
warn(f"could not find weights at {text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return text_encoder_weights
@pytest.fixture(scope="module")
def lda_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def unet_weights_std(test_weights_path: Path) -> Path:
unet_weights_std = test_weights_path / "unet.safetensors"
if not unet_weights_std.is_file():
warn(f"could not find weights at {unet_weights_std}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_std
@pytest.fixture(scope="module")
def unet_weights_inpainting(test_weights_path: Path) -> Path:
unet_weights_inpainting = test_weights_path / "inpainting" / "unet.safetensors"
if not unet_weights_inpainting.is_file():
warn(f"could not find weights at {unet_weights_inpainting}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_inpainting
@pytest.fixture(scope="module")
def lda_ft_mse_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda_ft_mse.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def ip_adapter_plus_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not image_encoder_weights.is_file():
warn(f"could not find weights at {image_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return image_encoder_weights
@pytest.fixture
def sd15_std(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
sd15 = StableDiffusion_1(device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@pytest.fixture
def sd15_std_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@pytest.fixture
def sd15_inpainting(
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
) -> StableDiffusion_1_Inpainting:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_inpainting)
return sd15
@pytest.fixture
def sd15_inpainting_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
) -> StableDiffusion_1_Inpainting:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_inpainting)
return sd15
@pytest.fixture
def sd15_ddim(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
ddim_solver = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@pytest.fixture
def sd15_ddim_karras(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@pytest.fixture
def sd15_euler(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
euler_solver = Euler(num_inference_steps=30)
sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@pytest.fixture
def sd15_ddim_lda_ft_mse(
text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
ddim_solver = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
return sd15
@pytest.fixture
def sdxl_lda_weights(test_weights_path: Path) -> Path:
sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
if not sdxl_lda_weights.is_file():
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_lda_weights
@pytest.fixture
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not sdxl_lda_weights.is_file():
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_lda_weights
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
if not sdxl_unet_weights.is_file():
warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_unet_weights
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not sdxl_double_text_encoder_weights.is_file():
warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_double_text_encoder_weights
@pytest.fixture
def sdxl_ddim(
sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
) -> StableDiffusion_XL:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
solver = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
return sdxl
@pytest.fixture
def sdxl_ddim_lda_fp16_fix(
sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
) -> StableDiffusion_XL:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
solver = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
return sdxl
@pytest.fixture
def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_XL:
return StableDiffusion_XL(
unet=sdxl_ddim.unet,
lda=sdxl_ddim.lda,
clip_text_encoder=sdxl_ddim.clip_text_encoder,
solver=Euler(num_inference_steps=30),
device=sdxl_ddim.device,
dtype=sdxl_ddim.dtype,
)
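# Every diffusion test below follows the same recipe: encode the prompt(s) with
# the CLIP text encoder, configure the solver via set_inference_steps(), seed the
# RNG, draw initial latents of shape (batch, 4, height / 8, width / 8), run one
# model call per solver step (classifier-free guidance via condition_scale), and
# finally decode the latents with the LDA. Below is a minimal sketch of that loop,
# assuming a model whose weights are already loaded (e.g. the sd15_std fixture
# above); it is illustrative only and not used by the tests:
def _example_txt2img_loop(sd15: StableDiffusion_1, prompt: str) -> Image.Image:
    clip_text_embedding = sd15.compute_clip_text_embedding(
        text=prompt, negative_text="lowres, bad anatomy, worst quality"
    )
    sd15.set_inference_steps(30)
    manual_seed(2)
    # 512x512 pixels correspond to 64x64 latents: the VAE downscales by a factor of 8
    x = torch.randn(1, 4, 64, 64, device=sd15.device, dtype=sd15.dtype)
    with no_grad():
        for step in sd15.steps:
            x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
    return sd15.lda.latents_to_image(x)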
@no_grad()
def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_random_init)
@no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
sd15 = sd15_std
prompt1 = "a cute cat, detailed high-quality professional image"
negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
prompt2 = "a cute dog"
negative_prompt2 = "lowres, bad anatomy, bad hands"
clip_text_embedding_b2 = sd15.compute_clip_text_embedding(
text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
)
step = sd15.steps[0]
manual_seed(2)
rand_b2 = torch.randn(2, 4, 64, 64, device=sd15.device)
x_b2 = sd15(
rand_b2,
step=step,
clip_text_embedding=clip_text_embedding_b2,
condition_scale=7.5,
)
assert x_b2.shape == (2, 4, 64, 64)
rand_1 = rand_b2[0:1]
clip_text_embedding_1 = sd15.compute_clip_text_embedding(text=[prompt1], negative_text=[negative_prompt1])
x_1 = sd15(
rand_1,
step=step,
clip_text_embedding=clip_text_embedding_1,
condition_scale=7.5,
)
rand_2 = rand_b2[1:2]
clip_text_embedding_2 = sd15.compute_clip_text_embedding(text=[prompt2], negative_text=[negative_prompt2])
x_2 = sd15(
rand_2,
step=step,
clip_text_embedding=clip_text_embedding_2,
condition_scale=7.5,
)
    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
    assert torch.allclose(
        x_b2[0], x_1[0], atol=5e-3, rtol=0
    ), f"batch-2 and batch-1 outputs should match, but they differ by {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
    assert torch.allclose(
        x_b2[1], x_2[0], atol=5e-3, rtol=0
    ), f"batch-2 and batch-1 outputs should match, but they differ by {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
@no_grad()
def test_diffusion_std_random_init_euler(
sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
):
sd15 = sd15_euler
euler_solver = sd15_euler.solver
assert isinstance(euler_solver, Euler)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
manual_seed(2)
x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_random_init_euler)
@no_grad()
def test_diffusion_karras_random_init(
sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_ddim_karras
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_std_random_init_float16(
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std_float16
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
assert clip_text_embedding.dtype == torch.float16
sd15.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_std_random_init_sag(
sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
):
sd15 = sd15_std
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
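    # Self-attention guidance (SAG) uses the UNet's own attention maps to degrade
    # salient regions and steers the sample away from that degraded prediction.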
sd15.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
@no_grad()
def test_diffusion_std_init_image(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,
expected_image_std_init_image: Image.Image,
):
sd15 = sd15_std
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(35, first_step=5)
manual_seed(2)
x = sd15.init_latents((512, 512), cutecat_init)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_init_image)
@no_grad()
def test_rectangular_init_latents(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,
):
sd15 = sd15_std
# Just check latents initialization with a non-square image (and not the entire diffusion)
width, height = 512, 504
rect_init_image = cutecat_init.crop((0, 0, width, height))
x = sd15.init_latents((height, width), rect_init_image)
assert sd15.lda.latents_to_image(x).size == (width, height)
@no_grad()
def test_diffusion_inpainting(
sd15_inpainting: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,
kitchen_dog_mask: Image.Image,
expected_image_std_inpainting: Image.Image,
test_device: torch.device,
):
sd15 = sd15_inpainting
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
    # Loose thresholds: with float32 we get large differences even against our own previous outputs.
    ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
@no_grad()
def test_diffusion_inpainting_float16(
sd15_inpainting_float16: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,
kitchen_dog_mask: Image.Image,
expected_image_std_inpainting: Image.Image,
test_device: torch.device,
):
sd15 = sd15_inpainting_float16
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
assert clip_text_embedding.dtype == torch.float16
sd15.set_inference_steps(30)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
    # Even looser thresholds: float16 drifts more than float32.
    ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
@no_grad()
def test_diffusion_controlnet(
sd15_std: StableDiffusion_1,
controlnet_data: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
if not cn_weights_path.is_file():
warn(f"could not find weights at {cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
).inject()
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
controlnet.set_controlnet_condition(cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_controlnet_structural_copy(
sd15_std: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15_base = sd15_std
sd15 = sd15_base.structural_copy()
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
if not cn_weights_path.is_file():
warn(f"could not find weights at {cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
).inject()
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
controlnet.set_controlnet_condition(cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_controlnet_float16(
sd15_std_float16: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std_float16
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
if not cn_weights_path.is_file():
warn(f"could not find weights at {cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
).inject()
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
controlnet.set_controlnet_condition(cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_controlnet_stack(
sd15_std: StableDiffusion_1,
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
expected_image_controlnet_stack: Image.Image,
test_device: torch.device,
):
sd15 = sd15_std
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
_, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
if not canny_cn_weights_path.is_file():
warn(f"could not find weights at {canny_cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
if not depth_cn_weights_path.is_file():
warn(f"could not find weights at {depth_cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
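    # Two ControlNet adapters can be injected into the same UNet; each residual is
    # weighted by its own scale before being added to the UNet features.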
depth_controlnet = SD1ControlnetAdapter(
sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
).inject()
canny_controlnet = SD1ControlnetAdapter(
sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors(canny_cn_weights_path)
).inject()
depth_cn_condition = image_to_tensor(depth_condition_image.convert("RGB"), device=test_device)
canny_cn_condition = image_to_tensor(canny_condition_image.convert("RGB"), device=test_device)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
depth_controlnet.set_controlnet_condition(depth_cn_condition)
canny_controlnet.set_controlnet_condition(canny_cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_sdxl_control_lora(
controllora_sdxl_config: tuple[Image.Image, dict[str, ControlLoraResolvedConfig]],
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
) -> None:
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
sdxl.dtype = torch.float16 # FIXME: should not be necessary
expected_image = controllora_sdxl_config[0]
configs = controllora_sdxl_config[1]
adapters: dict[str, ControlLoraAdapter] = {}
for config_name, config in configs.items():
adapter = ControlLoraAdapter(
name=config_name,
scale=config.scale,
target=sdxl.unet,
weights=load_from_safetensors(
path=config.weights_path,
device=sdxl.device,
),
)
adapter.set_condition(
image_to_tensor(
image=config.condition_image,
device=sdxl.device,
dtype=sdxl.dtype,
)
)
adapters[config_name] = adapter
# inject all the control lora adapters
for adapter in adapters.values():
adapter.inject()
# compute the text embeddings
prompt = "a cute cat, flying in the air, detailed high-quality professional image, blank background"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality, watermarks"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt,
negative_text=negative_prompt,
)
# initialize the latents
manual_seed(2)
x = torch.randn(
(1, 4, 128, 128),
device=sdxl.device,
dtype=sdxl.dtype,
)
# denoise
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=sdxl.default_time_ids,
)
# decode latent to image
    predicted_image = sdxl.lda.latents_to_image(x)
# ensure the predicted image is similar to the expected image
ensure_similar_images(
img_1=predicted_image,
img_2=expected_image,
min_psnr=35,
min_ssim=0.99,
)
@no_grad()
def test_diffusion_lora(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]],
test_device: torch.device,
) -> None:
sd15 = sd15_std
expected_image, lora_weights = lora_data_pokemon
prompt = "a cute cat"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_inference_steps(30)
SDLoraManager(sd15).add_loras("pokemon", lora_weights, scale=1)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None:
sdxl = sdxl_ddim
prompt1 = "a cute cat, detailed high-quality professional image"
negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
prompt2 = "a cute dog"
negative_prompt2 = "lowres, bad anatomy, bad hands"
clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding(
text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
)
time_ids = sdxl.default_time_ids
time_ids_b2 = sdxl.default_time_ids.repeat(2, 1)
manual_seed(seed=2)
x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
x_1 = x_b2[0:1]
x_2 = x_b2[1:2]
x_b2 = sdxl(
x_b2,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_b2,
pooled_text_embedding=pooled_text_embedding_b2,
time_ids=time_ids_b2,
)
clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding(
text=prompt1, negative_text=negative_prompt1
)
x_1 = sdxl(
x_1,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_1,
pooled_text_embedding=pooled_text_embedding_1,
time_ids=time_ids,
)
clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding(
text=prompt2, negative_text=negative_prompt2
)
x_2 = sdxl(
x_2,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_2,
pooled_text_embedding=pooled_text_embedding_2,
time_ids=time_ids,
)
    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
    assert torch.allclose(
        x_b2[0], x_1[0], atol=5e-3, rtol=0
    ), f"batch-2 and batch-1 outputs should match, but they differ by {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
    assert torch.allclose(
        x_b2[1], x_2[0], atol=5e-3, rtol=0
    ), f"batch-2 and batch-1 outputs should match, but they differ by {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
@no_grad()
def test_diffusion_sdxl_lora(
sdxl_ddim: StableDiffusion_XL,
lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
) -> None:
sdxl = sdxl_ddim
expected_image, lora_weights = lora_data_dpo
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
# except that we are using DDIM instead of sde-dpmsolver++
seed = 12341234123
guidance_scale = 7.5
lora_scale = 1.4
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale, unet_inclusions=["CrossAttentionBlock"])
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(40)
manual_seed(seed=seed)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=guidance_scale,
)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_sdxl_multiple_loras(
sdxl_ddim: StableDiffusion_XL,
lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
lora_sliders: tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]],
expected_sdxl_multi_loras: Image.Image,
) -> None:
sdxl = sdxl_ddim
expected_image = expected_sdxl_multi_loras
_, dpo_weights = lora_data_dpo
slider_loras, slider_scales = lora_sliders
manager = SDLoraManager(sdxl)
for lora_name, lora_weights in slider_loras.items():
manager.add_loras(
lora_name,
lora_weights,
slider_scales[lora_name],
unet_inclusions=["SelfAttention", "ResidualBlock", "Downsample", "Upsample"],
)
manager.add_loras("dpo", dpo_weights, 1.4, unet_inclusions=["CrossAttentionBlock"])
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
# except that we are using DDIM instead of sde-dpmsolver++
n_steps = 40
seed = 12341234123
guidance_scale = 4
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(n_steps)
manual_seed(seed=seed)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=guidance_scale,
)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_refonly(
sd15_ddim: StableDiffusion_1,
condition_image_refonly: Image.Image,
expected_image_refonly: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim
prompt = "Chicken"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
guide = sd15.lda.image_to_latents(condition_image_refonly)
guide = torch.cat((guide, guide))
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
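    # Reference-only control: at each step, the reference latents are re-noised to the
    # current noise level so the UNet's self-attention can attend to the reference.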
for step in sd15.steps:
noise = torch.randn(2, 4, 64, 64, device=test_device)
noised_guide = sd15.solver.add_noise(guide, noise, step)
refonly_adapter.set_controlnet_condition(noised_guide)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
        torch.randn(2, 4, 64, 64, device=test_device)  # for SD Web UI reproducibility only
predicted_image = sd15.lda.latents_to_image(x)
# min_psnr lowered to 33 because this reference image was generated without noise removal (see #192)
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99)
@no_grad()
def test_diffusion_inpainting_refonly(
sd15_inpainting: StableDiffusion_1_Inpainting,
scene_image_inpainting_refonly: Image.Image,
target_image_inpainting_refonly: Image.Image,
mask_image_inpainting_refonly: Image.Image,
expected_image_inpainting_refonly: Image.Image,
test_device: torch.device,
):
sd15 = sd15_inpainting
prompt = "" # unconditional
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
sd15.set_inference_steps(30)
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
guide = sd15.lda.image_to_latents(scene_image_inpainting_refonly)
guide = torch.cat((guide, guide))
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
noise = torch.randn_like(guide)
noised_guide = sd15.solver.add_noise(guide, noise, step)
# See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
# inpaint variation models")
noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
refonly_adapter.set_controlnet_condition(noised_guide)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
@no_grad()
def test_diffusion_textual_inversion_random_init(
sd15_std: StableDiffusion_1,
expected_image_textual_inversion_random_init: Image.Image,
text_embedding_textual_inversion: torch.Tensor,
test_device: torch.device,
):
sd15 = sd15_std
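    # ConceptExtender registers the learned embedding as a new "<gta5-artwork>"
    # token in the CLIP text encoder, usable directly in prompts.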
    concept_extender = ConceptExtender(sd15.clip_text_encoder)
    concept_extender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
    concept_extender.inject()
prompt = "a cute cat on a <gta5-artwork>"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_ip_adapter(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_weights: Path,
image_encoder_weights: Path,
woman_image: Image.Image,
expected_image_ip_adapter_woman: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
# See tencent-ailab/IP-Adapter best practices section:
#
# If you only use the image prompt, you can set the scale=1.0 and text_prompt="" (or some generic text
# prompts, e.g. "best quality", you can also use any negative text prompt).
#
# The prompts below are the ones used by default by IPAdapter's generate method if none are specified
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
sd15.set_inference_steps(50)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
@no_grad()
def test_diffusion_ip_adapter_multi(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_weights: Path,
image_encoder_weights: Path,
woman_image: Image.Image,
statue_image: Image.Image,
expected_image_ip_adapter_multi: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
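    # Condition on two image prompts at once; the per-image weights rescale each
    # embedding, here emphasizing the statue.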
clip_image_embedding = ip_adapter.compute_clip_image_embedding([woman_image, statue_image], weights=[1.0, 1.4])
ip_adapter.set_clip_image_embedding(clip_image_embedding)
sd15.set_inference_steps(50)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
    predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi)
@no_grad()
def test_diffusion_sdxl_ip_adapter(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_weights: Path,
image_encoder_weights: Path,
woman_image: Image.Image,
expected_image_sdxl_ip_adapter_woman: Image.Image,
test_device: torch.device,
):
sdxl = sdxl_ddim.to(dtype=torch.float16)
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
# See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
# internal activation values are too big"
sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
@no_grad()
def test_diffusion_ip_adapter_controlnet(
sd15_ddim: StableDiffusion_1,
ip_adapter_weights: Path,
image_encoder_weights: Path,
    lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]],
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
expected_image_ip_adapter_controlnet: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim.to(dtype=torch.float16)
input_image, _ = lora_data_pokemon # use the Pokemon LoRA output as input
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
depth_controlnet = SD1ControlnetAdapter(
sd15.unet,
name="depth",
scale=1.0,
weights=load_from_safetensors(depth_cn_weights_path),
).inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
depth_cn_condition = image_to_tensor(
depth_condition_image.convert("RGB"),
device=test_device,
dtype=torch.float16,
)
sd15.set_inference_steps(50)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
depth_controlnet.set_controlnet_condition(depth_cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
@no_grad()
def test_diffusion_ip_adapter_plus(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_plus_weights: Path,
image_encoder_weights: Path,
statue_image: Image.Image,
expected_image_ip_adapter_plus_statue: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(
target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True
)
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
sd15.set_inference_steps(50)
manual_seed(42) # seed=42 is used in the official IP-Adapter demo
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
    ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_ip_adapter_plus(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_plus_weights: Path,
image_encoder_weights: Path,
woman_image: Image.Image,
expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
test_device: torch.device,
):
sdxl = sdxl_ddim.to(dtype=torch.float16)
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SDXLIPAdapter(
target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True
)
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
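    # as in the test above, run the VAE in float32: the SDXL VAE generates NaNs in fp16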
sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
    ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)


@no_grad()
@pytest.mark.parametrize("structural_copy", [False, True])
def test_diffusion_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL,
expected_sdxl_ddim_random_init: Image.Image,
test_device: torch.device,
structural_copy: bool,
) -> None:
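    # structural_copy rebuilds the module tree while sharing weights, so outputs should be unchanged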
sdxl = sdxl_ddim.structural_copy() if structural_copy else sdxl_ddim
expected_image = expected_sdxl_ddim_random_init
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
manual_seed(seed=2)
x = torch.randn(1, 4, 128, 128, device=test_device)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.latents_to_image(x=x)
    ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_random_init_sag(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
) -> None:
sdxl = sdxl_ddim
expected_image = expected_sdxl_ddim_random_init_sag
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
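    # self-attention guidance steers sampling away from a blurred prediction derived from the UNet's own attention maps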
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(seed=2)
x = torch.randn(1, 4, 128, 128, device=test_device)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.latents_to_image(x=x)
    ensure_similar_images(img_1=predicted_image, img_2=expected_image)


@no_grad()
def test_diffusion_sdxl_sliced_attention(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image
) -> None:
unet = sdxl_ddim.unet.structural_copy()
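    # sliced attention processes queries in chunks of slice_size to reduce peak memory; the output should match the unsliced reference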
for layer in unet.layers(ScaledDotProductAttention):
layer.slice_size = 2048
sdxl = StableDiffusion_XL(
unet=unet,
lda=sdxl_ddim.lda,
clip_text_encoder=sdxl_ddim.clip_text_encoder,
solver=sdxl_ddim.solver,
device=sdxl_ddim.device,
dtype=sdxl_ddim.dtype,
)
expected_image = expected_sdxl_ddim_random_init
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.decode_latents(x)
    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_euler_deterministic(
sdxl_euler_deterministic: StableDiffusion_XL, expected_sdxl_euler_random_init: Image.Image
) -> None:
sdxl = sdxl_euler_deterministic
assert isinstance(sdxl.solver, Euler)
expected_image = expected_sdxl_euler_random_init
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
manual_seed(2)
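    # with a fixed seed, the deterministic Euler solver should reproduce the reference image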
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.decode_latents(x)
    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
manual_seed(seed=2)
sd = sd15_ddim
multi_diffusion = SD1MultiDiffusion(sd)
clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")
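    # two overlapping 64x64 tiles on a 64x80 latent canvas; MultiDiffusion merges the per-tile predictions in the overlap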
target_1 = DiffusionTarget(
size=(64, 64),
offset=(0, 0),
clip_text_embedding=clip_text_embedding,
start_step=0,
)
target_2 = DiffusionTarget(
size=(64, 64),
offset=(0, 16),
clip_text_embedding=clip_text_embedding,
condition_scale=3,
start_step=0,
)
noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)
x = noise
for step in sd.steps:
x = multi_diffusion(
x,
noise=noise,
step=step,
targets=[target_1, target_2],
)
result = sd.lda.latents_to_image(x=x)
    ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_t2i_adapter_depth(
sd15_std: StableDiffusion_1,
t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std
name, condition_image, expected_image, weights_path = t2i_adapter_data_depth
    if not weights_path.is_file():
        pytest.skip(f"could not find weights at {weights_path}")
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
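    # condition features are computed once from the depth map and reused at every denoising step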
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_t2i_adapter_xl_canny(
sdxl_ddim: StableDiffusion_XL,
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sdxl = sdxl_ddim
name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny
    if not weights_path.is_file():
        pytest.skip(f"could not find weights at {weights_path}")
prompt = "Mystical fairy in real, magic, 4k picture, high quality"
negative_prompt = (
"extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30)
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
t2i_adapter.scale = 0.8
condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
manual_seed(2)
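    # the latent size follows the condition image (the VAE downsamples by a factor of 8)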
x = torch.randn(1, 4, condition_image.height // 8, condition_image.width // 8, device=test_device)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=7.5,
)
predicted_image = sdxl.lda.latents_to_image(x)
    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_restart(
sd15_ddim: StableDiffusion_1,
expected_restart: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(30)
restart = Restart(ldm=sd15)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=8,
)
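        # at its start step, Restart re-injects noise and denoises again to improve sample quality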
if step == restart.start_step:
x = restart(
x,
clip_text_embedding=clip_text_embedding,
condition_scale=8,
)
predicted_image = sd15.lda.latents_to_image(x)
    ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_freeu(
sd15_std: StableDiffusion_1,
expected_freeu: Image.Image,
):
sd15 = sd15_std
prompt = "best quality, high quality cute cat"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(50, first_step=1)
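    # FreeU rescales the UNet's backbone features and skip connections during denoising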
SDFreeUAdapter(
sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2]
).inject()
manual_seed(9752)
x = sd15.init_latents((512, 512)).to(device=sd15.device, dtype=sd15.dtype)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
    ensure_similar_images(predicted_image, expected_freeu)


@no_grad()
def test_hello_world(
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
sdxl_ip_adapter_weights: Path,
image_encoder_weights: Path,
hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
) -> None:
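    # end-to-end example combining an init image, an IP-Adapter image prompt, a canny T2I-Adapter and self-attention guidance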
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
sdxl.dtype = torch.float16 # FIXME: should not be necessary
name, _, _, weights_path = t2i_adapter_xl_data_canny
init_image, image_prompt, condition_image, expected_image = hello_world_assets
    if not weights_path.is_file():
        pytest.skip(f"could not find weights at {weights_path}")
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
ip_adapter.set_clip_image_embedding(image_embedding)
# Note: default text prompts for IP-Adapter
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
)
time_ids = sdxl.default_time_ids
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
ip_adapter.scale = 0.85
t2i_adapter.scale = 0.8
sdxl.set_inference_steps(50, first_step=1)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(9752)
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(device=sdxl.device, dtype=sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.latents_to_image(x)
    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_style_aligned(
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
expected_style_aligned: Image.Image,
):
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
sdxl.dtype = torch.float16 # FIXME: should not be necessary
style_aligned_adapter = StyleAlignedAdapter(sdxl.unet)
style_aligned_adapter.inject()
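    # StyleAligned shares self-attention across the batch so all images are generated in a consistent style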
set_of_prompts = [
"a toy train. macro photo. 3d game asset",
"a toy airplane. macro photo. 3d game asset",
"a toy bicycle. macro photo. 3d game asset",
"a toy car. macro photo. 3d game asset",
"a toy boat. macro photo. 3d game asset",
]
# create (context) embeddings from prompts
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=set_of_prompts, negative_text=[""] * len(set_of_prompts)
)
time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)
# initialize latents
manual_seed(seed=2)
x = torch.randn(
(len(set_of_prompts), 4, 128, 128),
device=sdxl.device,
dtype=sdxl.dtype,
)
# denoise
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
# decode latents
predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x]
# tile all images horizontally
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
    for i, image in enumerate(predicted_images):
        merged_image.paste(image, (i * 1024, 0))  # type: ignore
# compare against reference image
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)