import gc
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
from warnings import warn

import pytest
import torch
from PIL import Image

from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion import (
    ControlLoraAdapter,
    SD1ControlnetAdapter,
    SD1IPAdapter,
    SD1T2IAdapter,
    SD1UNet,
    SDFreeUAdapter,
    SDXLIPAdapter,
    SDXLT2IAdapter,
    StableDiffusion_1,
    StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
from tests.utils import ensure_similar_images


def _img_open(path: Path) -> Image.Image:
    """Open the image at `path` (thin wrapper around `Image.open`, kept to localize the type-ignore)."""
    return Image.open(path)  # type: ignore


@pytest.fixture(autouse=True)
def ensure_gc():
    """Run a garbage collection before every test."""
    # Avoid GPU OOMs
    # See https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812
    gc.collect()


@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
    """Directory holding the reference (input and expected) images for these tests."""
    return test_e2e_path / "test_diffusion_ref"


@pytest.fixture(scope="module")
def cutecat_init(ref_path: Path) -> Image.Image:
    """`cutecat_init.png`, converted to RGB."""
    return _img_open(ref_path / "cutecat_init.png").convert("RGB")


@pytest.fixture(scope="module")
def kitchen_dog(ref_path: Path) -> Image.Image:
    """`kitchen_dog.png`, converted to RGB."""
    return _img_open(ref_path / "kitchen_dog.png").convert("RGB")


@pytest.fixture(scope="module")
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
    """`kitchen_dog_mask.png`, converted to RGB."""
    return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB")


@pytest.fixture(scope="module")
def woman_image(ref_path: Path) -> Image.Image:
    """`woman.png`, converted to RGB."""
    return _img_open(ref_path / "woman.png").convert("RGB")


@pytest.fixture(scope="module")
def statue_image(ref_path: Path) -> Image.Image:
    """`statue.png`, converted to RGB."""
    return _img_open(ref_path / "statue.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
    """Expected output `expected_std_random_init.png`, as RGB."""
    return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
    """Expected output `expected_std_random_init_euler.png`, as RGB."""
    return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")


@pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image:
    """Expected output `expected_karras_random_init.png`, as RGB."""
    return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
    """Expected output `expected_std_random_init_sag.png`, as RGB."""
    return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB")


@pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
    """Expected output `expected_std_init_image.png`, as RGB."""
    return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")


@pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
    """Expected output `expected_std_inpainting.png`, as RGB."""
    return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")


@pytest.fixture
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
    """Expected output `expected_controlnet_stack.png`, as RGB."""
    return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
    """Expected output `expected_image_ip_adapter_woman.png`, as RGB."""
    return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
    """Expected output `expected_image_ip_adapter_multi.png`, as RGB."""
    return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
    """Expected output `expected_image_ip_adapter_plus_statue.png`, as RGB."""
    return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")


@pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
    """Expected output `expected_image_sdxl_ip_adapter_woman.png`, as RGB."""
    return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")


@pytest.fixture
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
    """Expected output `expected_image_sdxl_ip_adapter_plus_woman.png`, as RGB."""
    return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
    """Expected output `expected_ip_adapter_controlnet.png`, as RGB."""
    return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")


@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
    """Expected output `expected_cutecat_sdxl_ddim_random_init.png`, as RGB."""
    return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")


@pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
    """Expected output `expected_cutecat_sdxl_ddim_random_init_sag.png`, as RGB."""
    return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")


@pytest.fixture
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
    """Expected output `expected_cutecat_sdxl_euler_random_init.png`, as RGB."""
    return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")


@pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image:
    """Expected output `expected_style_aligned.png`, as RGB."""
    return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB")


@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
    ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
    """Yield `(name, condition image, expected image, weights path)` for each ControlNet variant.

    The weights file is NOT checked for existence here; consuming tests do that themselves.
    """
    cn_name: str = request.param
    condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
    weights_fn = {
        "depth": "lllyasviel_control_v11f1p_sd15_depth",
        "canny": "lllyasviel_control_v11p_sd15_canny",
        "lineart": "lllyasviel_control_v11p_sd15_lineart",
        "normals": "lllyasviel_control_v11p_sd15_normalbae",
        "sam": "mfidabel_controlnet-segment-anything",
    }

    weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"

    yield (cn_name, condition_image, expected_image, weights_path)


@pytest.fixture(scope="module")
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
    """`(name, condition image, expected image, weights path)` for the canny ControlNet only (unchecked path)."""
    cn_name = "canny"
    condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
    weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
    return cn_name, condition_image, expected_image, weights_path


@pytest.fixture(scope="module")
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
    """`(name, condition image, expected image, weights path)` for the depth ControlNet only (unchecked path)."""
    cn_name = "depth"
    condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
    weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
    return cn_name, condition_image, expected_image, weights_path


@dataclass
class ControlLoraConfig:
    """Declarative Control-LoRA setup: paths are file names, resolved later by `controllora_sdxl_config`."""

    scale: float
    condition_path: str  # file name relative to `ref_path`
    weights_path: str  # file name relative to `test_weights_path / "control-loras"`


@dataclass
class ControlLoraResolvedConfig:
    """`ControlLoraConfig` with its condition image loaded and its weights path made absolute."""

    scale: float
    condition_image: Image.Image
    weights_path: Path


# Maps an expected-output image file name to the Control-LoRAs (by name) that must be applied to produce it.
CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
    "expected_controllora_PyraCanny.png": {
        "PyraCanny": ControlLoraConfig(
            scale=1.0,
            condition_path="cutecat_guide_PyraCanny.png",
            weights_path="refiners_control-lora-canny-rank128.safetensors",
        ),
    },
    "expected_controllora_CPDS.png": {
        "CPDS": ControlLoraConfig(
            scale=1.0,
            condition_path="cutecat_guide_CPDS.png",
            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
        ),
    },
    "expected_controllora_PyraCanny+CPDS.png": {
        "PyraCanny": ControlLoraConfig(
            scale=0.55,
            condition_path="cutecat_guide_PyraCanny.png",
            weights_path="refiners_control-lora-canny-rank128.safetensors",
        ),
        "CPDS": ControlLoraConfig(
            scale=0.55,
            condition_path="cutecat_guide_CPDS.png",
            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
        ),
    },
    "expected_controllora_disabled.png": {
        "PyraCanny": ControlLoraConfig(
            scale=0.0,
            condition_path="cutecat_guide_PyraCanny.png",
            weights_path="refiners_control-lora-canny-rank128.safetensors",
        ),
        "CPDS": ControlLoraConfig(
            scale=0.0,
            condition_path="cutecat_guide_CPDS.png",
            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
        ),
    },
}


@pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
def controllora_sdxl_config(
    request: pytest.FixtureRequest,
    ref_path: Path,
    test_weights_path: Path,
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
    """Resolve one `CONTROL_LORA_CONFIGS` entry into `(expected image, {name: resolved config})`."""
    name: str = request.param[0]
    configs: dict[str, ControlLoraConfig] = request.param[1]
    expected_image = _img_open(ref_path / name).convert("RGB")
    loaded_configs = {
        config_name: ControlLoraResolvedConfig(
            scale=config.scale,
            condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
            weights_path=test_weights_path / "control-loras" / config.weights_path,
        )
        for config_name, config in configs.items()
    }
    return expected_image, loaded_configs


@pytest.fixture(scope="module")
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
    """`(name, condition image, expected image, weights path)` for the SD1.5 depth T2I-Adapter.

    Skips (with a warning) when the adapter weights are missing, like `t2i_adapter_xl_data_canny`.
    """
    name = "depth"
    condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
    expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
    weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"

    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    return name, condition_image, expected_image, weights_path


@pytest.fixture(scope="module")
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
    """`(name, condition image, expected image, weights path)` for the SDXL canny T2I-Adapter; skips if missing."""
    name = "canny"
    condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
    expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
    weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"

    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    return name, condition_image, expected_image, weights_path


@pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
    """`(expected image, LoRA tensors)` for the pokemon LoRA (`.bin` file); skips if missing."""
    expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
    weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"

    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    tensors = load_tensors(weights_path)
    return expected_image, tensors


@pytest.fixture(scope="module")
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
    """`(expected image, LoRA tensors)` for the SDXL DPO LoRA (`.safetensors` file); skips if missing."""
    expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
    weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"

    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    tensors = load_from_safetensors(weights_path)
    return expected_image, tensors


@pytest.fixture(scope="module")
def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
    """Return `({slider name: tensors}, {slider name: scale})` for the age / cartoon_style / eyesize sliders.

    Skips (with a warning) when the sliders directory or any individual slider file is missing.
    """
    weights_path = test_weights_path / "loras" / "sliders"

    if not weights_path.is_dir():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    # The directory can exist with some files missing (partial download): skip instead of crashing in load_tensors.
    for slider_name in ("age", "cartoon_style", "eyesize"):
        slider_path = weights_path / f"{slider_name}.pt"
        if not slider_path.is_file():
            warn(f"could not find weights at {slider_path}, skipping")
            pytest.skip(allow_module_level=True)

    return {
        "age": load_tensors(weights_path / "age.pt"),  # type: ignore
        "cartoon_style": load_tensors(weights_path / "cartoon_style.pt"),  # type: ignore
        "eyesize": load_tensors(weights_path / "eyesize.pt"),  # type: ignore
    }, {
        "age": 0.3,
        "cartoon_style": -0.2,
        "eyesize": -0.2,
    }
@pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    """Reference image `inpainting-scene.png`, in RGB mode."""
    return _img_open(ref_path / "inpainting-scene.png").convert(mode="RGB")


@pytest.fixture
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    """Reference image `inpainting-mask.png`, in RGB mode."""
    return _img_open(ref_path / "inpainting-mask.png").convert(mode="RGB")


@pytest.fixture
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    """Reference image `inpainting-target.png`, in RGB mode."""
    return _img_open(ref_path / "inpainting-target.png").convert(mode="RGB")


@pytest.fixture
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    """Expected output `expected_inpainting_refonly.png`, in RGB mode."""
    return _img_open(ref_path / "expected_inpainting_refonly.png").convert(mode="RGB")


@pytest.fixture
def expected_image_refonly(ref_path: Path) -> Image.Image:
    """Expected output `expected_refonly.png`, in RGB mode."""
    return _img_open(ref_path / "expected_refonly.png").convert(mode="RGB")


@pytest.fixture
def condition_image_refonly(ref_path: Path) -> Image.Image:
    """Guide image `cyberpunk_guide.png`, in RGB mode."""
    return _img_open(ref_path / "cyberpunk_guide.png").convert(mode="RGB")


@pytest.fixture
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
    """Expected output `expected_textual_inversion_random_init.png`, in RGB mode."""
    return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert(mode="RGB")


@pytest.fixture
def expected_multi_diffusion(ref_path: Path) -> Image.Image:
    """Expected output `expected_multi_diffusion.png`, in RGB mode."""
    return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")


@pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image:
    """Expected output `expected_restart.png`, in RGB mode."""
    return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")


@pytest.fixture
def expected_freeu(ref_path: Path) -> Image.Image:
    """Expected output `expected_freeu.png`, in RGB mode."""
    return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB")


@pytest.fixture
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
    """Expected output `expected_sdxl_multi_loras.png`, in RGB mode."""
    return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")


@pytest.fixture
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
    """Return `(dropy logo, image prompt, canny condition, expected output)`, all in RGB mode.

    The first three live in the `assets` directory three levels above this file; the
    expected output comes from `ref_path`.
    """
    assets_dir = Path(__file__).parent.parent.parent / "assets"
    sources = (
        assets_dir / "dropy_logo.png",
        assets_dir / "dragon_quest_slime.jpg",
        assets_dir / "dropy_canny.png",
        ref_path / "expected_dropy_slime_9752.png",
    )
    dropy, image_prompt, condition_image, expected = (_img_open(src).convert(mode="RGB") for src in sources)
    return dropy, image_prompt, condition_image, expected


@pytest.fixture
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
    """Learned `gta5-artwork` embedding; the file stores it under the empty-string key."""
    embeddings = load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")
    return embeddings[""]


@pytest.fixture(scope="module")
def text_encoder_weights(test_weights_path: Path) -> Path:
    """Path to `CLIPTextEncoderL.safetensors`; warns and skips when the file is absent."""
    candidate = test_weights_path / "CLIPTextEncoderL.safetensors"
    if candidate.is_file():
        return candidate
    warn(f"could not find weights at {candidate}, skipping")
    pytest.skip(allow_module_level=True)


@pytest.fixture(scope="module")
def lda_weights(test_weights_path: Path) -> Path:
    """Path to `lda.safetensors`; warns and skips when the file is absent."""
    candidate = test_weights_path / "lda.safetensors"
    if candidate.is_file():
        return candidate
    warn(f"could not find weights at {candidate}, skipping")
    pytest.skip(allow_module_level=True)


@pytest.fixture(scope="module")
def unet_weights_std(test_weights_path: Path) -> Path:
    """Path to the standard `unet.safetensors`; warns and skips when the file is absent."""
    candidate = test_weights_path / "unet.safetensors"
    if candidate.is_file():
        return candidate
    warn(f"could not find weights at {candidate}, skipping")
    pytest.skip(allow_module_level=True)


@pytest.fixture(scope="module")
def unet_weights_inpainting(test_weights_path: Path) -> Path:
    """Path to `inpainting/unet.safetensors`; warns and skips when the file is absent."""
    candidate = test_weights_path / "inpainting" / "unet.safetensors"
    if candidate.is_file():
        return candidate
    warn(f"could not find weights at {candidate}, skipping")
    pytest.skip(allow_module_level=True)


@pytest.fixture(scope="module")
def lda_ft_mse_weights(test_weights_path: Path) -> Path:
    """Path to `lda_ft_mse.safetensors`; warns and skips when the file is absent."""
    candidate = test_weights_path / "lda_ft_mse.safetensors"
    if candidate.is_file():
        return candidate
    warn(f"could not find weights at {candidate}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module") def ip_adapter_weights(test_weights_path: Path) -> Path: ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors" if not ip_adapter_weights.is_file(): warn(f"could not find weights at {ip_adapter_weights}, skipping") pytest.skip(allow_module_level=True) return ip_adapter_weights @pytest.fixture(scope="module") def ip_adapter_plus_weights(test_weights_path: Path) -> Path: ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors" if not ip_adapter_weights.is_file(): warn(f"could not find weights at {ip_adapter_weights}, skipping") pytest.skip(allow_module_level=True) return ip_adapter_weights @pytest.fixture(scope="module") def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path: ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors" if not ip_adapter_weights.is_file(): warn(f"could not find weights at {ip_adapter_weights}, skipping") pytest.skip(allow_module_level=True) return ip_adapter_weights @pytest.fixture(scope="module") def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path: ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors" if not ip_adapter_weights.is_file(): warn(f"could not find weights at {ip_adapter_weights}, skipping") pytest.skip(allow_module_level=True) return ip_adapter_weights @pytest.fixture(scope="module") def image_encoder_weights(test_weights_path: Path) -> Path: image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors" if not image_encoder_weights.is_file(): warn(f"could not find weights at {image_encoder_weights}, skipping") pytest.skip(allow_module_level=True) return image_encoder_weights @pytest.fixture def sd15_std( text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() sd15 = StableDiffusion_1(device=test_device) 
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) sd15.unet.load_from_safetensors(unet_weights_std) return sd15 @pytest.fixture def sd15_std_float16( text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) sd15.unet.load_from_safetensors(unet_weights_std) return sd15 @pytest.fixture def sd15_inpainting( text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device ) -> StableDiffusion_1_Inpainting: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() unet = SD1UNet(in_channels=9) sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) sd15.unet.load_from_safetensors(unet_weights_inpainting) return sd15 @pytest.fixture def sd15_inpainting_float16( text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device ) -> StableDiffusion_1_Inpainting: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() unet = SD1UNet(in_channels=9) sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) sd15.unet.load_from_safetensors(unet_weights_inpainting) return sd15 @pytest.fixture def sd15_ddim( text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() ddim_solver = 
DDIM(num_inference_steps=20) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) sd15.unet.load_from_safetensors(unet_weights_std) return sd15 @pytest.fixture def sd15_ddim_karras( text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS)) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) sd15.unet.load_from_safetensors(unet_weights_std) return sd15 @pytest.fixture def sd15_euler( text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() euler_solver = Euler(num_inference_steps=30) sd15 = StableDiffusion_1(solver=euler_solver, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) sd15.unet.load_from_safetensors(unet_weights_std) return sd15 @pytest.fixture def sd15_ddim_lda_ft_mse( text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() ddim_solver = DDIM(num_inference_steps=20) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights)) sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std)) return sd15 @pytest.fixture def 
sdxl_lda_weights(test_weights_path: Path) -> Path: sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors" if not sdxl_lda_weights.is_file(): warn(message=f"could not find weights at {sdxl_lda_weights}, skipping") pytest.skip(allow_module_level=True) return sdxl_lda_weights @pytest.fixture def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path: sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors" if not sdxl_lda_weights.is_file(): warn(message=f"could not find weights at {sdxl_lda_weights}, skipping") pytest.skip(allow_module_level=True) return sdxl_lda_weights @pytest.fixture def sdxl_unet_weights(test_weights_path: Path) -> Path: sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors" if not sdxl_unet_weights.is_file(): warn(message=f"could not find weights at {sdxl_unet_weights}, skipping") pytest.skip(allow_module_level=True) return sdxl_unet_weights @pytest.fixture def sdxl_text_encoder_weights(test_weights_path: Path) -> Path: sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors" if not sdxl_double_text_encoder_weights.is_file(): warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping") pytest.skip(allow_module_level=True) return sdxl_double_text_encoder_weights @pytest.fixture def sdxl_ddim( sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device ) -> StableDiffusion_XL: if test_device.type == "cpu": warn(message="not running on CPU, skipping") pytest.skip() solver = DDIM(num_inference_steps=30) sdxl = StableDiffusion_XL(solver=solver, device=test_device) sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights) sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights) sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights) return sdxl @pytest.fixture def sdxl_ddim_lda_fp16_fix( sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: 
Path, test_device: torch.device ) -> StableDiffusion_XL: if test_device.type == "cpu": warn(message="not running on CPU, skipping") pytest.skip() solver = DDIM(num_inference_steps=30) sdxl = StableDiffusion_XL(solver=solver, device=test_device) sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights) sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights) sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights) return sdxl @pytest.fixture def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_XL: return StableDiffusion_XL( unet=sdxl_ddim.unet, lda=sdxl_ddim.lda, clip_text_encoder=sdxl_ddim.clip_text_encoder, solver=Euler(num_inference_steps=30), device=sdxl_ddim.device, dtype=sdxl_ddim.dtype, ) @no_grad() def test_diffusion_std_random_init( sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device ): sd15 = sd15_std prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_std_random_init) @no_grad() def test_diffusion_batch2(sd15_std: StableDiffusion_1): sd15 = sd15_std prompt1 = "a cute cat, detailed high-quality professional image" negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality" prompt2 = "a cute dog" negative_prompt2 = "lowres, bad anatomy, bad hands" clip_text_embedding_b2 = sd15.compute_clip_text_embedding( text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2] ) step = sd15.steps[0] manual_seed(2) rand_b2 = 
torch.randn(2, 4, 64, 64, device=sd15.device) x_b2 = sd15( rand_b2, step=step, clip_text_embedding=clip_text_embedding_b2, condition_scale=7.5, ) assert x_b2.shape == (2, 4, 64, 64) rand_1 = rand_b2[0:1] clip_text_embedding_1 = sd15.compute_clip_text_embedding(text=[prompt1], negative_text=[negative_prompt1]) x_1 = sd15( rand_1, step=step, clip_text_embedding=clip_text_embedding_1, condition_scale=7.5, ) rand_2 = rand_b2[1:2] clip_text_embedding_2 = sd15.compute_clip_text_embedding(text=[prompt2], negative_text=[negative_prompt2]) x_2 = sd15( rand_2, step=step, clip_text_embedding=clip_text_embedding_2, condition_scale=7.5, ) # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911 assert torch.allclose( x_b2[0], x_1[0], atol=5e-3, rtol=0 ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}" assert torch.allclose( x_b2[1], x_2[0], atol=5e-3, rtol=0 ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}" @no_grad() def test_diffusion_std_random_init_euler( sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device ): sd15 = sd15_euler euler_solver = sd15_euler.solver assert isinstance(euler_solver, Euler) prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) manual_seed(2) x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_std_random_init_euler) @no_grad() def test_diffusion_karras_random_init( sd15_ddim_karras: 
StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device ): sd15 = sd15_ddim_karras prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_std_random_init_float16( sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device ): sd15 = sd15_std_float16 prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) assert clip_text_embedding.dtype == torch.float16 sd15.set_inference_steps(30) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_std_random_init_sag( sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device ): sd15 = sd15_std prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) 
sd15.set_self_attention_guidance(enable=True, scale=0.75) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_std_random_init_sag) @no_grad() def test_diffusion_std_init_image( sd15_std: StableDiffusion_1, cutecat_init: Image.Image, expected_image_std_init_image: Image.Image, ): sd15 = sd15_std prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(35, first_step=5) manual_seed(2) x = sd15.init_latents((512, 512), cutecat_init) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_std_init_image) @no_grad() def test_rectangular_init_latents( sd15_std: StableDiffusion_1, cutecat_init: Image.Image, ): sd15 = sd15_std # Just check latents initialization with a non-square image (and not the entire diffusion) width, height = 512, 504 rect_init_image = cutecat_init.crop((0, 0, width, height)) x = sd15.init_latents((height, width), rect_init_image) assert sd15.lda.latents_to_image(x).size == (width, height) @no_grad() def test_diffusion_inpainting( sd15_inpainting: StableDiffusion_1_Inpainting, kitchen_dog: Image.Image, kitchen_dog_mask: Image.Image, expected_image_std_inpainting: Image.Image, test_device: torch.device, ): sd15 = sd15_inpainting prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, 
negative_text=negative_prompt) sd15.set_inference_steps(30) sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) # PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves. ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95) @no_grad() def test_diffusion_inpainting_float16( sd15_inpainting_float16: StableDiffusion_1_Inpainting, kitchen_dog: Image.Image, kitchen_dog_mask: Image.Image, expected_image_std_inpainting: Image.Image, test_device: torch.device, ): sd15 = sd15_inpainting_float16 prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) assert clip_text_embedding.dtype == torch.float16 sd15.set_inference_steps(30) sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) # PSNR and SSIM values are large because float16 is even worse than float32. 
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92) @no_grad() def test_diffusion_controlnet( sd15_std: StableDiffusion_1, controlnet_data: tuple[str, Image.Image, Image.Image, Path], test_device: torch.device, ): sd15 = sd15_std cn_name, condition_image, expected_image, cn_weights_path = controlnet_data if not cn_weights_path.is_file(): warn(f"could not find weights at {cn_weights_path}, skipping") pytest.skip(allow_module_level=True) prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) controlnet = SD1ControlnetAdapter( sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path) ).inject() cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: controlnet.set_controlnet_condition(cn_condition) x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_controlnet_structural_copy( sd15_std: StableDiffusion_1, controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path], test_device: torch.device, ): sd15_base = sd15_std sd15 = sd15_base.structural_copy() cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny if not cn_weights_path.is_file(): warn(f"could not find weights at {cn_weights_path}, skipping") pytest.skip(allow_module_level=True) prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, 
negative_text=negative_prompt) sd15.set_inference_steps(30) controlnet = SD1ControlnetAdapter( sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path) ).inject() cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: controlnet.set_controlnet_condition(cn_condition) x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_controlnet_float16( sd15_std_float16: StableDiffusion_1, controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path], test_device: torch.device, ): sd15 = sd15_std_float16 cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny if not cn_weights_path.is_file(): warn(f"could not find weights at {cn_weights_path}, skipping") pytest.skip(allow_module_level=True) prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) controlnet = SD1ControlnetAdapter( sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path) ).inject() cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) for step in sd15.steps: controlnet.set_controlnet_condition(cn_condition) x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_controlnet_stack( sd15_std: 
StableDiffusion_1, controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path], controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path], expected_image_controlnet_stack: Image.Image, test_device: torch.device, ): sd15 = sd15_std _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth _, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny if not canny_cn_weights_path.is_file(): warn(f"could not find weights at {canny_cn_weights_path}, skipping") pytest.skip(allow_module_level=True) if not depth_cn_weights_path.is_file(): warn(f"could not find weights at {depth_cn_weights_path}, skipping") pytest.skip(allow_module_level=True) prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) depth_controlnet = SD1ControlnetAdapter( sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path) ).inject() canny_controlnet = SD1ControlnetAdapter( sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors(canny_cn_weights_path) ).inject() depth_cn_condition = image_to_tensor(depth_condition_image.convert("RGB"), device=test_device) canny_cn_condition = image_to_tensor(canny_condition_image.convert("RGB"), device=test_device) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: depth_controlnet.set_controlnet_condition(depth_cn_condition) canny_controlnet.set_controlnet_condition(canny_cn_condition) x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_sdxl_control_lora( controllora_sdxl_config: tuple[Image.Image, dict[str, 
ControlLoraResolvedConfig]], sdxl_ddim_lda_fp16_fix: StableDiffusion_XL, ) -> None: sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16) sdxl.dtype = torch.float16 # FIXME: should not be necessary expected_image = controllora_sdxl_config[0] configs = controllora_sdxl_config[1] adapters: dict[str, ControlLoraAdapter] = {} for config_name, config in configs.items(): adapter = ControlLoraAdapter( name=config_name, scale=config.scale, target=sdxl.unet, weights=load_from_safetensors( path=config.weights_path, device=sdxl.device, ), ) adapter.set_condition( image_to_tensor( image=config.condition_image, device=sdxl.device, dtype=sdxl.dtype, ) ) adapters[config_name] = adapter # inject all the control lora adapters for adapter in adapters.values(): adapter.inject() # compute the text embeddings prompt = "a cute cat, flying in the air, detailed high-quality professional image, blank background" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality, watermarks" clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt, ) # initialize the latents manual_seed(2) x = torch.randn( (1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype, ) # denoise for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=sdxl.default_time_ids, ) # decode latent to image predicted_image = sdxl.lda.decode_latents(x) # ensure the predicted image is similar to the expected image ensure_similar_images( img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.99, ) @no_grad() def test_diffusion_lora( sd15_std: StableDiffusion_1, lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]], test_device: torch.device, ) -> None: sd15 = sd15_std expected_image, lora_weights = lora_data_pokemon prompt = "a cute cat" clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sd15.set_inference_steps(30) 
SDLoraManager(sd15).add_loras("pokemon", lora_weights, scale=1) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None: sdxl = sdxl_ddim prompt1 = "a cute cat, detailed high-quality professional image" negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality" prompt2 = "a cute dog" negative_prompt2 = "lowres, bad anatomy, bad hands" clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding( text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2] ) time_ids = sdxl.default_time_ids time_ids_b2 = sdxl.default_time_ids.repeat(2, 1) manual_seed(seed=2) x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) x_1 = x_b2[0:1] x_2 = x_b2[1:2] x_b2 = sdxl( x_b2, step=sdxl.steps[0], clip_text_embedding=clip_text_embedding_b2, pooled_text_embedding=pooled_text_embedding_b2, time_ids=time_ids_b2, ) clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding( text=prompt1, negative_text=negative_prompt1 ) x_1 = sdxl( x_1, step=sdxl.steps[0], clip_text_embedding=clip_text_embedding_1, pooled_text_embedding=pooled_text_embedding_1, time_ids=time_ids, ) clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding( text=prompt2, negative_text=negative_prompt2 ) x_2 = sdxl( x_2, step=sdxl.steps[0], clip_text_embedding=clip_text_embedding_2, pooled_text_embedding=pooled_text_embedding_2, time_ids=time_ids, ) # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911 assert torch.allclose( x_b2[0], x_1[0], atol=5e-3, rtol=0 ), f"Batch 2 and batch1 output should be the same and are 
distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}" assert torch.allclose( x_b2[1], x_2[0], atol=5e-3, rtol=0 ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}" @no_grad() def test_diffusion_sdxl_lora( sdxl_ddim: StableDiffusion_XL, lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]], ) -> None: sdxl = sdxl_ddim expected_image, lora_weights = lora_data_dpo # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA # except that we are using DDIM instead of sde-dpmsolver++ seed = 12341234123 guidance_scale = 7.5 lora_scale = 1.4 prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography" negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white" SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale, unet_inclusions=["CrossAttentionBlock"]) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(40) manual_seed(seed=seed) x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=guidance_scale, ) predicted_image = sdxl.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_sdxl_multiple_loras( sdxl_ddim: StableDiffusion_XL, lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]], lora_sliders: tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]], expected_sdxl_multi_loras: Image.Image, ) -> None: sdxl = sdxl_ddim expected_image = 
expected_sdxl_multi_loras _, dpo_weights = lora_data_dpo slider_loras, slider_scales = lora_sliders manager = SDLoraManager(sdxl) for lora_name, lora_weights in slider_loras.items(): manager.add_loras( lora_name, lora_weights, slider_scales[lora_name], unet_inclusions=["SelfAttention", "ResidualBlock", "Downsample", "Upsample"], ) manager.add_loras("dpo", dpo_weights, 1.4, unet_inclusions=["CrossAttentionBlock"]) # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA # except that we are using DDIM instead of sde-dpmsolver++ n_steps = 40 seed = 12341234123 guidance_scale = 4 prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography" negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white" clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(n_steps) manual_seed(seed=seed) x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=guidance_scale, ) predicted_image = sdxl.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_refonly( sd15_ddim: StableDiffusion_1, condition_image_refonly: Image.Image, expected_image_refonly: Image.Image, test_device: torch.device, ): sd15 = sd15_ddim prompt = "Chicken" clip_text_embedding = sd15.compute_clip_text_embedding(prompt) refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject() guide = sd15.lda.image_to_latents(condition_image_refonly) guide = torch.cat((guide, guide)) 
manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: noise = torch.randn(2, 4, 64, 64, device=test_device) noised_guide = sd15.solver.add_noise(guide, noise, step) refonly_adapter.set_controlnet_condition(noised_guide) x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) torch.randn(2, 4, 64, 64, device=test_device) # for SD Web UI reproductibility only predicted_image = sd15.lda.latents_to_image(x) # min_psnr lowered to 33 because this reference image was generated without noise removal (see #192) ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99) @no_grad() def test_diffusion_inpainting_refonly( sd15_inpainting: StableDiffusion_1_Inpainting, scene_image_inpainting_refonly: Image.Image, target_image_inpainting_refonly: Image.Image, mask_image_inpainting_refonly: Image.Image, expected_image_inpainting_refonly: Image.Image, test_device: torch.device, ): sd15 = sd15_inpainting prompt = "" # unconditional clip_text_embedding = sd15.compute_clip_text_embedding(prompt) refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject() sd15.set_inference_steps(30) sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly) guide = sd15.lda.image_to_latents(scene_image_inpainting_refonly) guide = torch.cat((guide, guide)) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: noise = torch.randn_like(guide) noised_guide = sd15.solver.add_noise(guide, noise, step) # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support # inpaint variation models") noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1) refonly_adapter.set_controlnet_condition(noised_guide) x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) 
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99) @no_grad() def test_diffusion_textual_inversion_random_init( sd15_std: StableDiffusion_1, expected_image_textual_inversion_random_init: Image.Image, text_embedding_textual_inversion: torch.Tensor, test_device: torch.device, ): sd15 = sd15_std conceptExtender = ConceptExtender(sd15.clip_text_encoder) conceptExtender.add_concept("", text_embedding_textual_inversion) conceptExtender.inject() prompt = "a cute cat on a " clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sd15.set_inference_steps(30) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_ip_adapter( sd15_ddim_lda_ft_mse: StableDiffusion_1, ip_adapter_weights: Path, image_encoder_weights: Path, woman_image: Image.Image, expected_image_ip_adapter_woman: Image.Image, test_device: torch.device, ): sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16) # See tencent-ailab/IP-Adapter best practices section: # # If you only use the image prompt, you can set the scale=1.0 and text_prompt="" (or some generic text # prompts, e.g. "best quality", you can also use any negative text prompt). 
# # The prompts below are the ones used by default by IPAdapter's generate method if none are specified prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights)) ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) ip_adapter.set_clip_image_embedding(clip_image_embedding) sd15.set_inference_steps(50) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_ip_adapter_woman) @no_grad() def test_diffusion_ip_adapter_multi( sd15_ddim_lda_ft_mse: StableDiffusion_1, ip_adapter_weights: Path, image_encoder_weights: Path, woman_image: Image.Image, statue_image: Image.Image, expected_image_ip_adapter_multi: Image.Image, test_device: torch.device, ): sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16) prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights)) ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding([woman_image, statue_image], weights=[1.0, 1.4]) ip_adapter.set_clip_image_embedding(clip_image_embedding) sd15.set_inference_steps(50) manual_seed(2) x = torch.randn(1, 4, 64, 64, 
device=test_device, dtype=torch.float16) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_ip_adapter_multi) @no_grad() def test_diffusion_sdxl_ip_adapter( sdxl_ddim: StableDiffusion_XL, sdxl_ip_adapter_weights: Path, image_encoder_weights: Path, woman_image: Image.Image, expected_image_sdxl_ip_adapter_woman: Image.Image, test_device: torch.device, ): sdxl = sdxl_ddim.to(dtype=torch.float16) prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights)) ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() with no_grad(): clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) ip_adapter.set_clip_image_embedding(clip_image_embedding) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(30) manual_seed(2) x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16) with no_grad(): for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=5, ) # See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the # internal activation values are too big" sdxl.lda.to(dtype=torch.float32) predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32)) ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman) @no_grad() def test_diffusion_ip_adapter_controlnet( sd15_ddim: StableDiffusion_1, ip_adapter_weights: Path, image_encoder_weights: Path, lora_data_pokemon: 
tuple[Image.Image, Path], controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path], expected_image_ip_adapter_controlnet: Image.Image, test_device: torch.device, ): sd15 = sd15_ddim.to(dtype=torch.float16) input_image, _ = lora_data_pokemon # use the Pokemon LoRA output as input _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights)) ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() depth_controlnet = SD1ControlnetAdapter( sd15.unet, name="depth", scale=1.0, weights=load_from_safetensors(depth_cn_weights_path), ).inject() clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image)) ip_adapter.set_clip_image_embedding(clip_image_embedding) depth_cn_condition = image_to_tensor( depth_condition_image.convert("RGB"), device=test_device, dtype=torch.float16, ) sd15.set_inference_steps(50) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) for step in sd15.steps: depth_controlnet.set_controlnet_condition(depth_cn_condition) x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet) @no_grad() def test_diffusion_ip_adapter_plus( sd15_ddim_lda_ft_mse: StableDiffusion_1, ip_adapter_plus_weights: Path, image_encoder_weights: Path, statue_image: Image.Image, expected_image_ip_adapter_plus_statue: Image.Image, test_device: torch.device, ): sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16) prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad 
anatomy, worst quality, low quality" ip_adapter = SD1IPAdapter( target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True ) ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image)) ip_adapter.set_clip_image_embedding(clip_image_embedding) sd15.set_inference_steps(50) manual_seed(42) # seed=42 is used in the official IP-Adapter demo x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_sdxl_ip_adapter_plus( sdxl_ddim: StableDiffusion_XL, sdxl_ip_adapter_plus_weights: Path, image_encoder_weights: Path, woman_image: Image.Image, expected_image_sdxl_ip_adapter_plus_woman: Image.Image, test_device: torch.device, ): sdxl = sdxl_ddim.to(dtype=torch.float16) prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" ip_adapter = SDXLIPAdapter( target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True ) ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) ip_adapter.set_clip_image_embedding(clip_image_embedding) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(30) manual_seed(2) x = torch.randn(1, 4, 128, 128, 
device=test_device, dtype=torch.float16) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=5, ) sdxl.lda.to(dtype=torch.float32) predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32)) ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman) @no_grad() @pytest.mark.parametrize("structural_copy", [False, True]) def test_diffusion_sdxl_random_init( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device, structural_copy: bool, ) -> None: sdxl = sdxl_ddim.structural_copy() if structural_copy else sdxl_ddim expected_image = expected_sdxl_ddim_random_init prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(30) manual_seed(seed=2) x = torch.randn(1, 4, 128, 128, device=test_device) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=5, ) predicted_image = sdxl.lda.latents_to_image(x=x) ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_sdxl_random_init_sag( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device ) -> None: sdxl = sdxl_ddim expected_image = expected_sdxl_ddim_random_init_sag prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) time_ids = 
sdxl.default_time_ids sdxl.set_inference_steps(30) sdxl.set_self_attention_guidance(enable=True, scale=0.75) manual_seed(seed=2) x = torch.randn(1, 4, 128, 128, device=test_device) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=5, ) predicted_image = sdxl.lda.latents_to_image(x=x) ensure_similar_images(img_1=predicted_image, img_2=expected_image) @no_grad() def test_diffusion_sdxl_sliced_attention( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image ) -> None: unet = sdxl_ddim.unet.structural_copy() for layer in unet.layers(ScaledDotProductAttention): layer.slice_size = 2048 sdxl = StableDiffusion_XL( unet=unet, lda=sdxl_ddim.lda, clip_text_encoder=sdxl_ddim.clip_text_encoder, solver=sdxl_ddim.solver, device=sdxl_ddim.device, dtype=sdxl_ddim.dtype, ) expected_image = expected_sdxl_ddim_random_init prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(30) manual_seed(2) x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=5, ) predicted_image = sdxl.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() def test_diffusion_sdxl_euler_deterministic( sdxl_euler_deterministic: StableDiffusion_XL, expected_sdxl_euler_random_init: Image.Image ) -> None: sdxl = sdxl_euler_deterministic assert isinstance(sdxl.solver, Euler) expected_image = expected_sdxl_euler_random_init prompt = "a cute cat, detailed high-quality professional 
image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(30) manual_seed(2) x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=5, ) predicted_image = sdxl.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image) @no_grad() def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None: manual_seed(seed=2) sd = sd15_ddim multi_diffusion = SD1MultiDiffusion(sd) clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain") target_1 = DiffusionTarget( size=(64, 64), offset=(0, 0), clip_text_embedding=clip_text_embedding, start_step=0, ) target_2 = DiffusionTarget( size=(64, 64), offset=(0, 16), clip_text_embedding=clip_text_embedding, condition_scale=3, start_step=0, ) noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype) x = noise for step in sd.steps: x = multi_diffusion( x, noise=noise, step=step, targets=[target_1, target_2], ) result = sd.lda.latents_to_image(x=x) ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98) @no_grad() def test_t2i_adapter_depth( sd15_std: StableDiffusion_1, t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path], test_device: torch.device, ): sd15 = sd15_std name, condition_image, expected_image, weights_path = t2i_adapter_data_depth if not weights_path.is_file(): warn(f"could not find weights at {weights_path}, skipping") pytest.skip(allow_module_level=True) prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" 
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject() condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition)) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image) @no_grad() def test_t2i_adapter_xl_canny( sdxl_ddim: StableDiffusion_XL, t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path], test_device: torch.device, ): sdxl = sdxl_ddim name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny if not weights_path.is_file(): warn(f"could not find weights at {weights_path}, skipping") pytest.skip(allow_module_level=True) prompt = "Mystical fairy in real, magic, 4k picture, high quality" negative_prompt = ( "extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured" ) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) time_ids = sdxl.default_time_ids sdxl.set_inference_steps(30) t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject() t2i_adapter.scale = 0.8 condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition)) manual_seed(2) x = torch.randn(1, 4, condition_image.height // 8, condition_image.width // 8, device=test_device) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, 
pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, condition_scale=7.5, ) predicted_image = sdxl.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image) @no_grad() def test_restart( sd15_ddim: StableDiffusion_1, expected_restart: Image.Image, test_device: torch.device, ): sd15 = sd15_ddim prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(30) restart = Restart(ldm=sd15) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=8, ) if step == restart.start_step: x = restart( x, clip_text_embedding=clip_text_embedding, condition_scale=8, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98) @no_grad() def test_freeu( sd15_std: StableDiffusion_1, expected_freeu: Image.Image, ): sd15 = sd15_std prompt = "best quality, high quality cute cat" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_inference_steps(50, first_step=1) SDFreeUAdapter( sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2] ).inject() manual_seed(9752) x = sd15.init_latents((512, 512)).to(device=sd15.device, dtype=sd15.dtype) for step in sd15.steps: x = sd15( x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5, ) predicted_image = sd15.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_freeu) @no_grad() def test_hello_world( sdxl_ddim_lda_fp16_fix: StableDiffusion_XL, t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path], 
sdxl_ip_adapter_weights: Path, image_encoder_weights: Path, hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image], ) -> None: sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16) sdxl.dtype = torch.float16 # FIXME: should not be necessary name, _, _, weights_path = t2i_adapter_xl_data_canny init_image, image_prompt, condition_image, expected_image = hello_world_assets if not weights_path.is_file(): warn(f"could not find weights at {weights_path}, skipping") pytest.skip(allow_module_level=True) ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights)) ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt)) ip_adapter.set_clip_image_embedding(image_embedding) # Note: default text prompts for IP-Adapter clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality" ) time_ids = sdxl.default_time_ids t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject() condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype) t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition)) ip_adapter.scale = 0.85 t2i_adapter.scale = 0.8 sdxl.set_inference_steps(50, first_step=1) sdxl.set_self_attention_guidance(enable=True, scale=0.75) manual_seed(9752) x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(device=sdxl.device, dtype=sdxl.dtype) for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, ) predicted_image = sdxl.lda.latents_to_image(x) ensure_similar_images(predicted_image, expected_image) @no_grad() def 
test_style_aligned( sdxl_ddim_lda_fp16_fix: StableDiffusion_XL, expected_style_aligned: Image.Image, ): sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16) sdxl.dtype = torch.float16 # FIXME: should not be necessary style_aligned_adapter = StyleAlignedAdapter(sdxl.unet) style_aligned_adapter.inject() set_of_prompts = [ "a toy train. macro photo. 3d game asset", "a toy airplane. macro photo. 3d game asset", "a toy bicycle. macro photo. 3d game asset", "a toy car. macro photo. 3d game asset", "a toy boat. macro photo. 3d game asset", ] # create (context) embeddings from prompts clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=set_of_prompts, negative_text=[""] * len(set_of_prompts) ) time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1) # initialize latents manual_seed(seed=2) x = torch.randn( (len(set_of_prompts), 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype, ) # denoise for step in sdxl.steps: x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, pooled_text_embedding=pooled_text_embedding, time_ids=time_ids, ) # decode latents predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x] # tile all images horizontally merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024)) for i in range(len(predicted_images)): merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore # compare against reference image ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)