import gc
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
from warnings import warn

import pytest
import torch
from PIL import Image
from tests.utils import T5TextEmbedder, ensure_similar_images

from refiners.conversion import controllora_sdxl
from refiners.conversion.utils import Hub
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import (
    ControlLoraAdapter,
    SD1ControlnetAdapter,
    SD1ELLAAdapter,
    SD1IPAdapter,
    SD1T2IAdapter,
    SD1UNet,
    SDFreeUAdapter,
    SDXLIPAdapter,
    SDXLT2IAdapter,
    StableDiffusion_1,
    StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import Size, Tile
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import (
    SD1DiffusionTarget,
    SD1MultiDiffusion,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter

from ..weight_paths import get_path


@pytest.fixture(autouse=True)
def ensure_gc():
    # Avoid GPU OOMs
    # See https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812
    gc.collect()


@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
    return test_e2e_path / "test_diffusion_ref"


@pytest.fixture(scope="module")
def cutecat_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "cutecat_init.png").convert("RGB")


@pytest.fixture(scope="module")
def kitchen_dog(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "kitchen_dog.png").convert("RGB")


@pytest.fixture(scope="module")
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")


@pytest.fixture(scope="module")
def woman_image(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "woman.png").convert("RGB")


@pytest.fixture(scope="module")
def statue_image(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "statue.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")


@pytest.fixture
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_sde_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB")


@pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")


@pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")


@pytest.fixture
def expected_image_ella_adapter(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_image_ella_adapter.png").convert("RGB")


@pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")


@pytest.fixture
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")


@pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")


@pytest.fixture
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")


@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")


@pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")


@pytest.fixture
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")


@pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_style_aligned.png").convert(mode="RGB")


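# Parametrized over several ControlNet variants: each case yields the variant name, its condition
# image, the expected output image, and the path to the converted weights (the *_weights_path
# fixtures are provided by fixtures defined elsewhere in the test suite).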
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
    ref_path: Path,
    controlnet_depth_weights_path: Path,
    controlnet_canny_weights_path: Path,
    controlnet_lineart_weights_path: Path,
    controlnet_normals_weights_path: Path,
    controlnet_sam_weights_path: Path,
    request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
    cn_name: str = request.param
    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")

    weights_fn = {
        "depth": controlnet_depth_weights_path,
        "canny": controlnet_canny_weights_path,
        "lineart": controlnet_lineart_weights_path,
        "normals": controlnet_normals_weights_path,
        "sam": controlnet_sam_weights_path,
    }
    weights_path = weights_fn[cn_name]

    yield cn_name, condition_image, expected_image, weights_path


@pytest.fixture(scope="module", params=["canny"])
def controlnet_data_scale_decay(
    ref_path: Path,
    controlnet_canny_weights_path: Path,
    request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
    cn_name: str = request.param
    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")

    weights_fn = {
        "canny": controlnet_canny_weights_path,
    }
    weights_path = weights_fn[cn_name]

    yield cn_name, condition_image, expected_image, weights_path


@pytest.fixture(scope="module")
def controlnet_data_tile(
    ref_path: Path,
    controlnet_tiles_weights_path: Path,
) -> tuple[Image.Image, Image.Image, Path]:
    condition_image = Image.open(ref_path / "low_res_dog.png").convert("RGB").resize((1024, 1024))  # type: ignore
    expected_image = Image.open(ref_path / "expected_controlnet_tile.png").convert("RGB")
    return condition_image, expected_image, controlnet_tiles_weights_path


@pytest.fixture(scope="module")
def controlnet_data_canny(
    ref_path: Path,
    controlnet_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
    cn_name = "canny"
    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
    return cn_name, condition_image, expected_image, controlnet_canny_weights_path


@pytest.fixture(scope="module")
def controlnet_data_depth(
    ref_path: Path,
    controlnet_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
    cn_name = "depth"
    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
    return cn_name, condition_image, expected_image, controlnet_depth_weights_path


@dataclass
class ControlLoraConfig:
    scale: float
    condition_path: str
    weights: Hub


@dataclass
class ControlLoraResolvedConfig:
    scale: float
    condition_image: Image.Image
    weights_path: Path


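# Maps each expected reference image to the control-LoRA configuration(s) used to produce it
# (PyraCanny and/or CPDS, at various scales); the Hub weights are resolved to local paths below.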
CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
    "expected_controllora_PyraCanny.png": {
        "PyraCanny": ControlLoraConfig(
            scale=1.0,
            condition_path="cutecat_guide_PyraCanny.png",
            weights=controllora_sdxl.canny.converted,
        ),
    },
    "expected_controllora_CPDS.png": {
        "CPDS": ControlLoraConfig(
            scale=1.0,
            condition_path="cutecat_guide_CPDS.png",
            weights=controllora_sdxl.cpds.converted,
        ),
    },
    "expected_controllora_PyraCanny+CPDS.png": {
        "PyraCanny": ControlLoraConfig(
            scale=0.55,
            condition_path="cutecat_guide_PyraCanny.png",
            weights=controllora_sdxl.canny.converted,
        ),
        "CPDS": ControlLoraConfig(
            scale=0.55,
            condition_path="cutecat_guide_CPDS.png",
            weights=controllora_sdxl.cpds.converted,
        ),
    },
    "expected_controllora_disabled.png": {
        "PyraCanny": ControlLoraConfig(
            scale=0.0,
            condition_path="cutecat_guide_PyraCanny.png",
            weights=controllora_sdxl.canny.converted,
        ),
        "CPDS": ControlLoraConfig(
            scale=0.0,
            condition_path="cutecat_guide_CPDS.png",
            weights=controllora_sdxl.cpds.converted,
        ),
    },
}


@pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
def controllora_sdxl_config(
    request: pytest.FixtureRequest,
    use_local_weights: bool,
    ref_path: Path,
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
    name: str = request.param[0]
    configs: dict[str, ControlLoraConfig] = request.param[1]

    expected_image = Image.open(ref_path / name).convert("RGB")

    loaded_configs = {
        config_name: ControlLoraResolvedConfig(
            scale=config.scale,
            condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
            weights_path=get_path(config.weights, use_local_weights),
        )
        for config_name, config in configs.items()
    }

    return expected_image, loaded_configs


@pytest.fixture(scope="module")
def t2i_adapter_data_depth(
    ref_path: Path,
    t2i_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
    name = "depth"
    condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
    expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
    return name, condition_image, expected_image, t2i_depth_weights_path


@pytest.fixture(scope="module")
def t2i_adapter_xl_data_canny(
    ref_path: Path,
    t2i_sdxl_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
    name = "canny"
    condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
    expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
    return name, condition_image, expected_image, t2i_sdxl_canny_weights_path


@pytest.fixture(scope="module")
def lora_data_pokemon(
    ref_path: Path,
    lora_pokemon_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
    expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
    tensors = load_tensors(lora_pokemon_weights_path)
    return expected_image, tensors


@pytest.fixture(scope="module")
def lora_data_dpo(
    ref_path: Path,
    lora_dpo_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
    expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
    tensors = load_from_safetensors(lora_dpo_weights_path)
    return expected_image, tensors


@pytest.fixture(scope="module")
def lora_sliders(
    lora_slider_age_weights_path: Path,
    lora_slider_cartoon_style_weights_path: Path,
    lora_slider_eyesize_weights_path: Path,
) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
    return {
        "age": load_tensors(lora_slider_age_weights_path),
        "cartoon_style": load_tensors(lora_slider_cartoon_style_weights_path),
        "eyesize": load_tensors(lora_slider_eyesize_weights_path),
    }, {
        "age": 0.3,
        "cartoon_style": -0.2,
        "eyesize": -0.2,
    }


@pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "inpainting-scene.png").convert("RGB")


@pytest.fixture
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "inpainting-mask.png").convert("RGB")


@pytest.fixture
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "inpainting-target.png").convert("RGB")


@pytest.fixture
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")


@pytest.fixture
def expected_image_refonly(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_refonly.png").convert("RGB")


@pytest.fixture
def condition_image_refonly(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")


@pytest.fixture
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")


@pytest.fixture
def expected_multi_diffusion(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")


@pytest.fixture
def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")


@pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_restart.png").convert(mode="RGB")


@pytest.fixture
def expected_freeu(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_freeu.png").convert(mode="RGB")


@pytest.fixture
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")


@pytest.fixture
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
    assets = Path(__file__).parent.parent.parent / "assets"
    dropy = assets / "dropy_logo.png"
    image_prompt = assets / "dragon_quest_slime.jpg"
    condition_image = assets / "dropy_canny.png"
    return (
        Image.open(dropy).convert(mode="RGB"),
        Image.open(image_prompt).convert(mode="RGB"),
        Image.open(condition_image).convert(mode="RGB"),
        Image.open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
    )


@pytest.fixture
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
    return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]


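# The model fixtures below build Stable Diffusion pipelines on the test device; they skip when the
# test device is CPU and load converted weights from paths supplied by fixtures defined elsewhere
# in the test suite.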
@pytest.fixture
def sd15_std(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    sd15 = StableDiffusion_1(device=test_device)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sd15_std_sde(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
    sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sd15_std_float16(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sd15_std_bfloat16(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sd15_inpainting(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_inpainting_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1_Inpainting:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    unet = SD1UNet(in_channels=9)
    sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
    return sd15


@pytest.fixture
def sd15_inpainting_float16(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_inpainting_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1_Inpainting:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    unet = SD1UNet(in_channels=9)
    sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
    return sd15


@pytest.fixture
def sd15_ddim(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    ddim_solver = DDIM(num_inference_steps=20)
    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sd15_ddim_karras(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sd15_euler(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    euler_solver = Euler(num_inference_steps=30)
    sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sd15_ddim_lda_ft_mse(
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_mse_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_1:
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    ddim_solver = DDIM(num_inference_steps=20)
    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
    sd15.lda.load_from_safetensors(sd15_autoencoder_mse_weights_path)
    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
    return sd15


@pytest.fixture
def sdxl_ddim(
    sdxl_text_encoder_weights_path: Path,
    sdxl_autoencoder_weights_path: Path,
    sdxl_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_XL:
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    solver = DDIM(num_inference_steps=30)
    sdxl = StableDiffusion_XL(solver=solver, device=test_device)
    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
    sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_weights_path)
    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
    return sdxl


@pytest.fixture
def sdxl_ddim_lda_fp16_fix(
    sdxl_text_encoder_weights_path: Path,
    sdxl_autoencoder_fp16fix_weights_path: Path,
    sdxl_unet_weights_path: Path,
    test_device: torch.device,
) -> StableDiffusion_XL:
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    solver = DDIM(num_inference_steps=30)
    sdxl = StableDiffusion_XL(solver=solver, device=test_device)
    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
    sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
    return sdxl


@pytest.fixture
def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_XL:
    return StableDiffusion_XL(
        unet=sdxl_ddim.unet,
        lda=sdxl_ddim.lda,
        clip_text_encoder=sdxl_ddim.clip_text_encoder,
        solver=Euler(num_inference_steps=30),
        device=sdxl_ddim.device,
        dtype=sdxl_ddim.dtype,
    )


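# MultiUpscaler loads its own SD 1.5 checkpoints (UNet, text encoder, MSE-finetuned autoencoder)
# plus the tile ControlNet, so it takes weight paths directly instead of a prebuilt model fixture.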
@pytest.fixture(scope="module")
def multi_upscaler(
    controlnet_tiles_weights_path: Path,
    sd15_text_encoder_weights_path: Path,
    sd15_autoencoder_mse_weights_path: Path,
    sd15_unet_weights_path: Path,
    test_device: torch.device,
) -> MultiUpscaler:
    return MultiUpscaler(
        checkpoints=UpscalerCheckpoints(
            unet=sd15_unet_weights_path,
            clip_text_encoder=sd15_text_encoder_weights_path,
            lda=sd15_autoencoder_mse_weights_path,
            controlnet_tile=controlnet_tiles_weights_path,
        ),
        device=test_device,
        dtype=torch.float32,
    )


@pytest.fixture(scope="module")
def clarity_example(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "clarity_input_example.png")


@pytest.fixture(scope="module")
def expected_multi_upscaler(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_multi_upscaler.png")


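# Most end-to-end tests below follow the same pattern: encode the prompts with CLIP, seed the RNG,
# start from random (or image-initialized) latents, run the denoising loop step by step, decode the
# latents with the autoencoder, and compare the result against a pre-computed reference image.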
@no_grad()
def test_diffusion_std_random_init(
    sd15_std: StableDiffusion_1,
    expected_image_std_random_init: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_std
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_random_init)


@no_grad()
def test_diffusion_std_random_init_bfloat16(
    sd15_std_bfloat16: StableDiffusion_1,
    expected_image_std_random_init_bfloat16: Image.Image,
):
    sd15 = sd15_std_bfloat16
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=sd15.device, dtype=sd15.dtype)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16, min_psnr=30, min_ssim=0.97)


@no_grad()
def test_diffusion_std_sde_random_init(
    sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
):
    sd15 = sd15_std_sde
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(50)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_sde_random_init)


@no_grad()
def test_diffusion_std_sde_karras_random_init(
    sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device
):
    sd15 = sd15_std_sde
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.solver = DPMSolver(
        num_inference_steps=18,
        last_step_first_order=True,
        params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS),
        device=test_device,
    )

    manual_seed(2)
    x = sd15.init_latents((512, 512))

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init)


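# Sanity check for batching: a single denoising step applied to a batch of two prompts must match
# the result of running each prompt separately (up to a small numerical tolerance).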
@no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
    sd15 = sd15_std
    prompt1 = "a cute cat, detailed high-quality professional image"
    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
    prompt2 = "a cute dog"
    negative_prompt2 = "lowres, bad anatomy, bad hands"
    clip_text_embedding_b2 = sd15.compute_clip_text_embedding(
        text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
    )

    step = sd15.steps[0]

    manual_seed(2)
    rand_b2 = torch.randn(2, 4, 64, 64, device=sd15.device)

    x_b2 = sd15(
        rand_b2,
        step=step,
        clip_text_embedding=clip_text_embedding_b2,
        condition_scale=7.5,
    )
    assert x_b2.shape == (2, 4, 64, 64)

    rand_1 = rand_b2[0:1]
    clip_text_embedding_1 = sd15.compute_clip_text_embedding(text=[prompt1], negative_text=[negative_prompt1])
    x_1 = sd15(
        rand_1,
        step=step,
        clip_text_embedding=clip_text_embedding_1,
        condition_scale=7.5,
    )

    rand_2 = rand_b2[1:2]
    clip_text_embedding_2 = sd15.compute_clip_text_embedding(text=[prompt2], negative_text=[negative_prompt2])
    x_2 = sd15(
        rand_2,
        step=step,
        clip_text_embedding=clip_text_embedding_2,
        condition_scale=7.5,
    )

    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
    assert torch.allclose(
        x_b2[0], x_1[0], atol=5e-3, rtol=0
    ), f"batched and single outputs should match for prompt 1; max difference: {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
    assert torch.allclose(
        x_b2[1], x_2[0], atol=5e-3, rtol=0
    ), f"batched and single outputs should match for prompt 2; max difference: {torch.max((x_b2[1] - x_2[0]).abs()).item()}"


@no_grad()
def test_diffusion_std_random_init_euler(
    sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
):
    sd15 = sd15_euler
    euler_solver = sd15_euler.solver
    assert isinstance(euler_solver, Euler)

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    manual_seed(2)
    x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_random_init_euler)


@no_grad()
def test_diffusion_karras_random_init(
    sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
):
    sd15 = sd15_ddim_karras
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_std_random_init_float16(
    sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
    sd15 = sd15_std_float16
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    assert clip_text_embedding.dtype == torch.float16

    sd15.set_inference_steps(30)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_std_random_init_sag(
    sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
):
    sd15 = sd15_std
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)
    sd15.set_self_attention_guidance(enable=True, scale=0.75)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_random_init_sag)


@no_grad()
def test_diffusion_std_init_image(
    sd15_std: StableDiffusion_1,
    cutecat_init: Image.Image,
    expected_image_std_init_image: Image.Image,
):
    sd15 = sd15_std
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(35, first_step=5)

    manual_seed(2)
    x = sd15.init_latents((512, 512), cutecat_init)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_std_init_image)


@no_grad()
def test_rectangular_init_latents(
    sd15_std: StableDiffusion_1,
    cutecat_init: Image.Image,
):
    sd15 = sd15_std

    # Just check latents initialization with a non-square image (and not the entire diffusion)
    width, height = 512, 504
    rect_init_image = cutecat_init.crop((0, 0, width, height))
    x = sd15.init_latents((height, width), rect_init_image)

    assert sd15.lda.latents_to_image(x).size == (width, height)


@no_grad()
def test_diffusion_inpainting(
    sd15_inpainting: StableDiffusion_1_Inpainting,
    kitchen_dog: Image.Image,
    kitchen_dog_mask: Image.Image,
    expected_image_std_inpainting: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_inpainting
    prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)
    sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    # The PSNR and SSIM thresholds are loose because, even in float32, we get large differences vs. our own outputs.
    ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)


@no_grad()
def test_diffusion_inpainting_float16(
    sd15_inpainting_float16: StableDiffusion_1_Inpainting,
    kitchen_dog: Image.Image,
    kitchen_dog_mask: Image.Image,
    expected_image_std_inpainting: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_inpainting_float16
    prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    assert clip_text_embedding.dtype == torch.float16

    sd15.set_inference_steps(30)
    sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    # The thresholds are loose because float16 drifts even more than float32.
    ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95, min_dinov2=0.96)


@no_grad()
def test_diffusion_controlnet(
    sd15_std: StableDiffusion_1,
    controlnet_data: tuple[str, Image.Image, Image.Image, Path],
    test_device: torch.device,
):
    sd15 = sd15_std
    cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
    if not cn_weights_path.is_file():
        warn(f"could not find weights at {cn_weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    controlnet = SD1ControlnetAdapter(
        sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
    ).inject()
    cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        controlnet.set_controlnet_condition(cn_condition)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


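# Tile ControlNet upscaling: the low-resolution input is resized to 1024x1024, used both to seed
# the latents and as the ControlNet condition, then denoised to recover a sharper image.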
@no_grad()
def test_diffusion_controlnet_tile_upscale(
    sd15_std: StableDiffusion_1,
    controlnet_data_tile: tuple[Image.Image, Image.Image, Path],
    test_device: torch.device,
):
    sd15 = sd15_std
    condition_image, expected_image, cn_weights_path = controlnet_data_tile

    controlnet: SD1ControlnetAdapter = SD1ControlnetAdapter(
        sd15.unet, name="tile", scale=1.0, weights=load_from_safetensors(cn_weights_path)
    ).inject()
    cn_condition = image_to_tensor(condition_image, device=test_device)

    prompt = "best quality"
    negative_prompt = "blur, lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    manual_seed(0)
    x = sd15.init_latents((1024, 1024), condition_image).to(test_device)

    for step in sd15.steps:
        controlnet.set_controlnet_condition(cn_condition)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    # Note: rather large tolerances are used on purpose here (loose comparison with diffusers' output)
    ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75, min_dinov2=0.94)


@no_grad()
def test_diffusion_controlnet_scale_decay(
    sd15_std: StableDiffusion_1,
    controlnet_data_scale_decay: tuple[str, Image.Image, Image.Image, Path],
    test_device: torch.device,
):
    sd15 = sd15_std
    cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_scale_decay
    if not cn_weights_path.is_file():
        warn(f"could not find weights at {cn_weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    # Using default value of 0.825 chosen by lvmin
    # https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472
    controlnet = SD1ControlnetAdapter(
        sd15.unet, name=cn_name, scale=0.5, scale_decay=0.825, weights=load_from_safetensors(cn_weights_path)
    ).inject()
    cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        controlnet.set_controlnet_condition(cn_condition)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_controlnet_structural_copy(
    sd15_std: StableDiffusion_1,
    controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
    test_device: torch.device,
):
    sd15_base = sd15_std
    sd15 = sd15_base.structural_copy()
    cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
    if not cn_weights_path.is_file():
        warn(f"could not find weights at {cn_weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    controlnet = SD1ControlnetAdapter(
        sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
    ).inject()
    cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        controlnet.set_controlnet_condition(cn_condition)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


@no_grad ( )
2023-08-04 13:28:41 +00:00
def test_diffusion_controlnet_float16 (
sd15_std_float16 : StableDiffusion_1 ,
controlnet_data_canny : tuple [ str , Image . Image , Image . Image , Path ] ,
test_device : torch . device ,
) :
sd15 = sd15_std_float16
cn_name , condition_image , expected_image , cn_weights_path = controlnet_data_canny
if not cn_weights_path . is_file ( ) :
warn ( f " could not find weights at { cn_weights_path } , skipping " )
pytest . skip ( allow_module_level = True )
prompt = " a cute cat, detailed high-quality professional image "
negative_prompt = " lowres, bad anatomy, bad hands, cropped, worst quality "
2023-09-12 15:55:39 +00:00
clip_text_embedding = sd15 . compute_clip_text_embedding ( text = prompt , negative_text = negative_prompt )
2023-08-04 13:28:41 +00:00
2024-01-19 09:55:04 +00:00
sd15 . set_inference_steps ( 30 )
2023-08-04 13:28:41 +00:00
2023-08-31 08:40:01 +00:00
controlnet = SD1ControlnetAdapter (
sd15 . unet , name = cn_name , scale = 0.5 , weights = load_from_safetensors ( cn_weights_path )
) . inject ( )
2023-08-04 13:28:41 +00:00
cn_condition = image_to_tensor ( condition_image . convert ( " RGB " ) , device = test_device , dtype = torch . float16 )
manual_seed ( 2 )
x = torch . randn ( 1 , 4 , 64 , 64 , device = test_device , dtype = torch . float16 )
2023-09-12 15:55:39 +00:00
for step in sd15 . steps :
controlnet . set_controlnet_condition ( cn_condition )
x = sd15 (
x ,
step = step ,
clip_text_embedding = clip_text_embedding ,
condition_scale = 7.5 ,
)
2024-02-01 14:05:43 +00:00
predicted_image = sd15 . lda . latents_to_image ( x )
2023-08-04 13:28:41 +00:00
ensure_similar_images ( predicted_image , expected_image , min_psnr = 35 , min_ssim = 0.98 )


@no_grad()
def test_diffusion_controlnet_stack(
    sd15_std: StableDiffusion_1,
    controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
    controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
    expected_image_controlnet_stack: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_std
    _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
    _, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
    if not canny_cn_weights_path.is_file():
        warn(f"could not find weights at {canny_cn_weights_path}, skipping")
        pytest.skip(allow_module_level=True)
    if not depth_cn_weights_path.is_file():
        warn(f"could not find weights at {depth_cn_weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    depth_controlnet = SD1ControlnetAdapter(
        sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
    ).inject()
    canny_controlnet = SD1ControlnetAdapter(
        sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors(canny_cn_weights_path)
    ).inject()

    depth_cn_condition = image_to_tensor(depth_condition_image.convert("RGB"), device=test_device)
    canny_cn_condition = image_to_tensor(canny_condition_image.convert("RGB"), device=test_device)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        depth_controlnet.set_controlnet_condition(depth_cn_condition)
        canny_controlnet.set_controlnet_condition(canny_cn_condition)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
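

# Note: when several ControlNet adapters are injected on the same UNet (as above), each one
# contributes its own residuals weighted by its own scale (0.3 for depth, 0.7 for canny here),
# so the two scales can be read as a rough blend between the two conditions.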


@no_grad()
def test_diffusion_sdxl_control_lora(
    controllora_sdxl_config: tuple[Image.Image, dict[str, ControlLoraResolvedConfig]],
    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
) -> None:
    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
    sdxl.dtype = torch.float16  # FIXME: should not be necessary
    expected_image = controllora_sdxl_config[0]
    configs = controllora_sdxl_config[1]

    adapters: dict[str, ControlLoraAdapter] = {}
    for config_name, config in configs.items():
        if not config.weights_path.is_file():
            pytest.skip(f"File not found: {config.weights_path}")

        adapter = ControlLoraAdapter(
            name=config_name,
            scale=config.scale,
            target=sdxl.unet,
            weights=load_from_safetensors(
                path=config.weights_path,
                device=sdxl.device,
            ),
        )
        adapter.set_condition(
            image_to_tensor(
                image=config.condition_image,
                device=sdxl.device,
                dtype=sdxl.dtype,
            )
        )
        adapters[config_name] = adapter

    # inject all the control lora adapters
    for adapter in adapters.values():
        adapter.inject()

    # compute the text embeddings
    prompt = "a cute cat, flying in the air, detailed high-quality professional image, blank background"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality, watermarks"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt,
        negative_text=negative_prompt,
    )

    # initialize the latents
    manual_seed(2)
    x = torch.randn(
        (1, 4, 128, 128),
        device=sdxl.device,
        dtype=sdxl.dtype,
    )

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=sdxl.default_time_ids,
        )

    # decode latent to image
    predicted_image = sdxl.lda.latents_to_image(x)

    # ensure the predicted image is similar to the expected image
    ensure_similar_images(
        img_1=predicted_image,
        img_2=expected_image,
        min_psnr=35,
        min_ssim=0.99,
    )


@no_grad()
def test_diffusion_lora(
    sd15_std: StableDiffusion_1,
    lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]],
    test_device: torch.device,
) -> None:
    sd15 = sd15_std
    expected_image, lora_weights = lora_data_pokemon

    prompt = "a cute cat"
    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

    sd15.set_inference_steps(30)

    SDLoraManager(sd15).add_loras("pokemon", lora_weights, scale=1)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None:
    sdxl = sdxl_ddim

    prompt1 = "a cute cat, detailed high-quality professional image"
    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
    prompt2 = "a cute dog"
    negative_prompt2 = "lowres, bad anatomy, bad hands"

    clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding(
        text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
    )
    time_ids = sdxl.default_time_ids
    time_ids_b2 = sdxl.default_time_ids.repeat(2, 1)

    manual_seed(seed=2)
    x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
    x_1 = x_b2[0:1]
    x_2 = x_b2[1:2]

    x_b2 = sdxl(
        x_b2,
        step=sdxl.steps[0],
        clip_text_embedding=clip_text_embedding_b2,
        pooled_text_embedding=pooled_text_embedding_b2,
        time_ids=time_ids_b2,
    )

    clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding(
        text=prompt1, negative_text=negative_prompt1
    )
    x_1 = sdxl(
        x_1,
        step=sdxl.steps[0],
        clip_text_embedding=clip_text_embedding_1,
        pooled_text_embedding=pooled_text_embedding_1,
        time_ids=time_ids,
    )

    clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding(
        text=prompt2, negative_text=negative_prompt2
    )
    x_2 = sdxl(
        x_2,
        step=sdxl.steps[0],
        clip_text_embedding=clip_text_embedding_2,
        pooled_text_embedding=pooled_text_embedding_2,
        time_ids=time_ids,
    )

    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
    assert torch.allclose(
        x_b2[0], x_1[0], atol=5e-3, rtol=0
    ), f"batch-2 and batch-1 outputs should match but differ by {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
    assert torch.allclose(
        x_b2[1], x_2[0], atol=5e-3, rtol=0
    ), f"batch-2 and batch-1 outputs should match but differ by {torch.max((x_b2[1] - x_2[0]).abs()).item()}"


@no_grad()
def test_diffusion_sdxl_lora(
    sdxl_ddim: StableDiffusion_XL,
    lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
) -> None:
    sdxl = sdxl_ddim
    expected_image, lora_weights = lora_data_dpo

    # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
    # except that we are using DDIM instead of sde-dpmsolver++
    seed = 12341234123
    guidance_scale = 7.5
    lora_scale = 1.4
    prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
    negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"

    SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale, unet_inclusions=["CrossAttentionBlock"])

    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_inference_steps(40)

    manual_seed(seed=seed)
    x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=guidance_scale,
        )
    predicted_image = sdxl.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_multiple_loras(
    sdxl_ddim: StableDiffusion_XL,
    lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
    lora_sliders: tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]],
    expected_sdxl_multi_loras: Image.Image,
) -> None:
    sdxl = sdxl_ddim
    expected_image = expected_sdxl_multi_loras
    _, dpo_weights = lora_data_dpo
    slider_loras, slider_scales = lora_sliders

    manager = SDLoraManager(sdxl)
    for lora_name, lora_weights in slider_loras.items():
        manager.add_loras(
            lora_name,
            lora_weights,
            slider_scales[lora_name],
            unet_inclusions=["SelfAttention", "ResidualBlock", "Downsample", "Upsample"],
        )
    manager.add_loras("dpo", dpo_weights, 1.4, unet_inclusions=["CrossAttentionBlock"])

    # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
    # except that we are using DDIM instead of sde-dpmsolver++
    n_steps = 40
    seed = 12341234123
    guidance_scale = 4
    prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
    negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"

    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_inference_steps(n_steps)

    manual_seed(seed=seed)
    x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=guidance_scale,
        )
    predicted_image = sdxl.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
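

# Note: `unet_inclusions` restricts which UNet submodules receive LoRA layers. In the tests above,
# the DPO LoRA only targets cross-attention blocks, while the slider LoRAs target self-attention,
# residual, downsample and upsample blocks.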


@no_grad()
def test_diffusion_refonly(
    sd15_ddim: StableDiffusion_1,
    condition_image_refonly: Image.Image,
    expected_image_refonly: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_ddim
    prompt = "Chicken"
    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

    refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()

    guide = sd15.lda.image_to_latents(condition_image_refonly)
    guide = torch.cat((guide, guide))

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        noise = torch.randn(2, 4, 64, 64, device=test_device)
        noised_guide = sd15.solver.add_noise(guide, noise, step)
        refonly_adapter.set_controlnet_condition(noised_guide)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
        torch.randn(2, 4, 64, 64, device=test_device)  # for SD Web UI reproducibility only
    predicted_image = sd15.lda.latents_to_image(x)

    # min_psnr lowered to 33 because this reference image was generated without noise removal (see #192)
    ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99)


@no_grad()
def test_diffusion_inpainting_refonly(
    sd15_inpainting: StableDiffusion_1_Inpainting,
    scene_image_inpainting_refonly: Image.Image,
    target_image_inpainting_refonly: Image.Image,
    mask_image_inpainting_refonly: Image.Image,
    expected_image_inpainting_refonly: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_inpainting
    prompt = ""  # unconditional
    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

    refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()

    sd15.set_inference_steps(30)
    sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)

    guide = sd15.lda.image_to_latents(scene_image_inpainting_refonly)
    guide = torch.cat((guide, guide))

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        noise = torch.randn_like(guide)
        noised_guide = sd15.solver.add_noise(guide, noise, step)
        # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
        # inpaint variation models")
        noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
        refonly_adapter.set_controlnet_condition(noised_guide)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)


@no_grad()
def test_diffusion_textual_inversion_random_init(
    sd15_std: StableDiffusion_1,
    expected_image_textual_inversion_random_init: Image.Image,
    text_embedding_textual_inversion: torch.Tensor,
    test_device: torch.device,
):
    sd15 = sd15_std

    concept_extender = ConceptExtender(sd15.clip_text_encoder)
    concept_extender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
    concept_extender.inject()

    prompt = "a cute cat on a <gta5-artwork>"
    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

    sd15.set_inference_steps(30)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_ella_adapter(
    sd15_std_float16: StableDiffusion_1,
    ella_sd15_tsc_t5xl_weights_path: Path,
    t5xl_transformers_path: str,
    expected_image_ella_adapter: Image.Image,
    test_device: torch.device,
    use_local_weights: bool,
):
    sd15 = sd15_std_float16
    t5_encoder = T5TextEmbedder(
        pretrained_path=t5xl_transformers_path,
        local_files_only=use_local_weights,
        max_length=128,
    ).to(test_device, torch.float16)

    prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
    negative_prompt = ""
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    assert clip_text_embedding.dtype == torch.float16
    llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt)
    prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)

    adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_sd15_tsc_t5xl_weights_path))
    adapter.inject()

    sd15.set_inference_steps(50)

    manual_seed(1001)
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        adapter.set_llm_text_embedding(prompt_embedding)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=12,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=31, min_ssim=0.98)


@no_grad()
def test_diffusion_ip_adapter(
    sd15_ddim_lda_ft_mse: StableDiffusion_1,
    ip_adapter_sd15_weights_path: Path,
    clip_image_encoder_huge_weights_path: Path,
    woman_image: Image.Image,
    expected_image_ip_adapter_woman: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)

    # See tencent-ailab/IP-Adapter best practices section:
    #
    #     If you only use the image prompt, you can set the scale=1.0 and text_prompt="" (or some generic text
    #     prompts, e.g. "best quality", you can also use any negative text prompt).
    #
    # The prompts below are the ones used by default by IPAdapter's generate method if none are specified.
    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
    ip_adapter.inject()

    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)

    sd15.set_inference_steps(50)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)


@no_grad()
def test_diffusion_ip_adapter_multi(
    sd15_ddim_lda_ft_mse: StableDiffusion_1,
    ip_adapter_sd15_weights_path: Path,
    clip_image_encoder_huge_weights_path: Path,
    woman_image: Image.Image,
    statue_image: Image.Image,
    expected_image_ip_adapter_multi: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
    ip_adapter.inject()

    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    clip_image_embedding = ip_adapter.compute_clip_image_embedding([woman_image, statue_image], weights=[1.0, 1.4])
    ip_adapter.set_clip_image_embedding(clip_image_embedding)

    sd15.set_inference_steps(50)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_ip_adapter(
    sdxl_ddim: StableDiffusion_XL,
    ip_adapter_sdxl_weights_path: Path,
    clip_image_encoder_huge_weights_path: Path,
    woman_image: Image.Image,
    expected_image_sdxl_ip_adapter_woman: Image.Image,
    test_device: torch.device,
):
    sdxl = sdxl_ddim.to(dtype=torch.float16)
    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
    ip_adapter.inject()

    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)

    time_ids = sdxl.default_time_ids
    sdxl.set_inference_steps(30)

    manual_seed(2)
    x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    # See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
    # internal activation values are too big"
    sdxl.lda.to(dtype=torch.float32)
    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))

    ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)


@no_grad()
def test_diffusion_ip_adapter_controlnet(
    sd15_ddim: StableDiffusion_1,
    ip_adapter_sd15_weights_path: Path,
    clip_image_encoder_huge_weights_path: Path,
    lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]],
    controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
    expected_image_ip_adapter_controlnet: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_ddim.to(dtype=torch.float16)
    input_image, _ = lora_data_pokemon  # use the Pokemon LoRA output as input
    _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
    ip_adapter.inject()

    depth_controlnet = SD1ControlnetAdapter(
        sd15.unet,
        name="depth",
        scale=1.0,
        weights=load_from_safetensors(depth_cn_weights_path),
    ).inject()

    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)

    depth_cn_condition = image_to_tensor(
        depth_condition_image.convert("RGB"),
        device=test_device,
        dtype=torch.float16,
    )

    sd15.set_inference_steps(50)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        depth_controlnet.set_controlnet_condition(depth_cn_condition)
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)


@no_grad()
def test_diffusion_ip_adapter_plus(
    sd15_ddim_lda_ft_mse: StableDiffusion_1,
    ip_adapter_sd15_plus_weights_path: Path,
    clip_image_encoder_huge_weights_path: Path,
    statue_image: Image.Image,
    expected_image_ip_adapter_plus_statue: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SD1IPAdapter(
        target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_plus_weights_path), fine_grained=True
    )
    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
    ip_adapter.inject()

    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)

    sd15.set_inference_steps(50)

    manual_seed(42)  # seed=42 is used in the official IP-Adapter demo
    x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_ip_adapter_plus(
    sdxl_ddim: StableDiffusion_XL,
    ip_adapter_sdxl_plus_weights_path: Path,
    clip_image_encoder_huge_weights_path: Path,
    woman_image: Image.Image,
    expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
    test_device: torch.device,
):
    sdxl = sdxl_ddim.to(dtype=torch.float16)
    prompt = "best quality, high quality"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    ip_adapter = SDXLIPAdapter(
        target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path), fine_grained=True
    )
    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
    ip_adapter.inject()

    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)

    time_ids = sdxl.default_time_ids
    sdxl.set_inference_steps(30)

    manual_seed(2)
    x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    sdxl.lda.to(dtype=torch.float32)
    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))

    ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman, min_psnr=43, min_ssim=0.98)


@no_grad()
@pytest.mark.parametrize("structural_copy", [False, True])
def test_diffusion_sdxl_random_init(
    sdxl_ddim: StableDiffusion_XL,
    expected_sdxl_ddim_random_init: Image.Image,
    test_device: torch.device,
    structural_copy: bool,
) -> None:
    sdxl = sdxl_ddim.structural_copy() if structural_copy else sdxl_ddim
    expected_image = expected_sdxl_ddim_random_init

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_inference_steps(30)

    manual_seed(seed=2)
    x = torch.randn(1, 4, 128, 128, device=test_device)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.latents_to_image(x=x)

    ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_random_init_sag(
    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
) -> None:
    sdxl = sdxl_ddim
    expected_image = expected_sdxl_ddim_random_init_sag

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_inference_steps(30)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    manual_seed(seed=2)
    x = torch.randn(1, 4, 128, 128, device=test_device)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.latents_to_image(x=x)

    ensure_similar_images(img_1=predicted_image, img_2=expected_image)


@no_grad()
def test_diffusion_sdxl_sliced_attention(
    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image
) -> None:
    unet = sdxl_ddim.unet.structural_copy()
    for layer in unet.layers(ScaledDotProductAttention):
        layer.slice_size = 2048
    sdxl = StableDiffusion_XL(
        unet=unet,
        lda=sdxl_ddim.lda,
        clip_text_encoder=sdxl_ddim.clip_text_encoder,
        solver=sdxl_ddim.solver,
        device=sdxl_ddim.device,
        dtype=sdxl_ddim.dtype,
    )
    expected_image = expected_sdxl_ddim_random_init

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_inference_steps(30)

    manual_seed(2)
    x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
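

# Note: slicing the scaled dot-product attention trades peak memory for extra kernel launches;
# the outputs are expected to match the unsliced path up to numerical noise, which is why this
# test reuses the same reference image as test_diffusion_sdxl_random_init.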


@no_grad()
def test_diffusion_sdxl_euler_deterministic(
    sdxl_euler_deterministic: StableDiffusion_XL, expected_sdxl_euler_random_init: Image.Image
) -> None:
    sdxl = sdxl_euler_deterministic
    assert isinstance(sdxl.solver, Euler)
    expected_image = expected_sdxl_euler_random_init

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_inference_steps(30)

    manual_seed(2)
    x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
    manual_seed(seed=2)
    sd = sd15_ddim
    multi_diffusion = SD1MultiDiffusion(sd)
    clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")

    # DDIM doesn't have an internal state, so we can share the same solver for all targets
    target_1 = SD1DiffusionTarget(
        tile=Tile(top=0, left=0, bottom=64, right=64),
        solver=sd.solver,
        clip_text_embedding=clip_text_embedding,
    )
    target_2 = SD1DiffusionTarget(
        solver=sd.solver,
        tile=Tile(top=0, left=16, bottom=64, right=80),
        clip_text_embedding=clip_text_embedding,
        condition_scale=3,
        start_step=0,
    )

    noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)
    x = noise

    for step in sd.steps:
        x = multi_diffusion(
            x,
            noise=noise,
            step=step,
            targets=[target_1, target_2],
        )
    result = sd.lda.latents_to_image(x=x)

    ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
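

# The helper below is a minimal, illustrative sketch of the MultiDiffusion idea (it is not used by
# the tests and is not the SD1MultiDiffusion implementation): overlapping tile predictions are merged
# by averaging every latent position over the tiles that cover it. Tile fields (top, left, bottom,
# right) are assumed to be latent-space coordinates.
def _example_merge_tiles_uniform(
    tile_latents: list[torch.Tensor], tiles: list[Tile], canvas: torch.Tensor
) -> torch.Tensor:
    accumulator = torch.zeros_like(canvas)
    coverage = torch.zeros_like(canvas)
    for latent, tile in zip(tile_latents, tiles):
        accumulator[:, :, tile.top : tile.bottom, tile.left : tile.right] += latent
        coverage[:, :, tile.top : tile.bottom, tile.left : tile.right] += 1
    return accumulator / coverage.clamp(min=1)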


@no_grad()
def test_multi_diffusion_dpm(sd15_std: StableDiffusion_1, expected_multi_diffusion_dpm: Image.Image) -> None:
    manual_seed(seed=2)
    sd = sd15_std
    multi_diffusion = SD1MultiDiffusion(sd)
    clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")

    tiles = SD1MultiDiffusion.generate_latent_tiles(size=Size(112, 196), tile_size=Size(96, 64), min_overlap=12)
    targets = [
        SD1DiffusionTarget(
            tile=tile,
            solver=DPMSolver(num_inference_steps=sd.solver.num_inference_steps, device=sd.device),
            clip_text_embedding=clip_text_embedding,
        )
        for tile in tiles
    ]

    noise = torch.randn(1, 4, 112, 196, device=sd.device, dtype=sd.dtype)
    x = noise

    for step in sd.steps:
        x = multi_diffusion(
            x,
            noise=noise,
            step=step,
            targets=targets,
        )
    result = sd.lda.latents_to_image(x=x)

    ensure_similar_images(img_1=result, img_2=expected_multi_diffusion_dpm, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_t2i_adapter_depth(
    sd15_std: StableDiffusion_1,
    t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path],
    test_device: torch.device,
):
    sd15 = sd15_std
    name, condition_image, expected_image, weights_path = t2i_adapter_data_depth
    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)

    t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject()

    condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_t2i_adapter_xl_canny(
    sdxl_ddim: StableDiffusion_XL,
    t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
    test_device: torch.device,
):
    sdxl = sdxl_ddim
    name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny
    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    prompt = "Mystical fairy in real, magic, 4k picture, high quality"
    negative_prompt = (
        "extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
    )
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_inference_steps(30)

    t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
    t2i_adapter.scale = 0.8

    condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))

    manual_seed(2)
    x = torch.randn(1, 4, condition_image.height // 8, condition_image.width // 8, device=test_device)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=7.5,
        )
    predicted_image = sdxl.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_restart(
    sd15_ddim: StableDiffusion_1,
    expected_restart: Image.Image,
    test_device: torch.device,
):
    sd15 = sd15_ddim
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(30)
    restart = Restart(ldm=sd15)

    manual_seed(2)
    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=8,
        )
        if step == restart.start_step:
            x = restart(
                x,
                clip_text_embedding=clip_text_embedding,
                condition_scale=8,
            )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_freeu(
    sd15_std: StableDiffusion_1,
    expected_freeu: Image.Image,
):
    sd15 = sd15_std
    prompt = "best quality, high quality cute cat"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_inference_steps(50, first_step=1)

    SDFreeUAdapter(
        sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2]
    ).inject()

    manual_seed(9752)
    x = sd15.init_latents((512, 512)).to(device=sd15.device, dtype=sd15.dtype)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_freeu)
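

# Note: FreeU re-weights the UNet decoder at inference time; roughly speaking, backbone_scales
# amplify the backbone features while skip_scales damp the skip connections, so no retraining or
# extra weights are involved.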


@no_grad()
def test_hello_world(
    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
    t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
    ip_adapter_sdxl_weights_path: Path,
    clip_image_encoder_huge_weights_path: Path,
    hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
) -> None:
    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
    sdxl.dtype = torch.float16  # FIXME: should not be necessary
    name, _, _, weights_path = t2i_adapter_xl_data_canny
    init_image, image_prompt, condition_image, expected_image = hello_world_assets
    if not weights_path.is_file():
        warn(f"could not find weights at {weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
    ip_adapter.inject()

    image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
    ip_adapter.set_clip_image_embedding(image_embedding)

    # Note: default text prompts for IP-Adapter
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
    )
    time_ids = sdxl.default_time_ids

    t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
    condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))

    ip_adapter.scale = 0.85
    t2i_adapter.scale = 0.8

    sdxl.set_inference_steps(50, first_step=1)
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    manual_seed(9752)
    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(device=sdxl.device, dtype=sdxl.dtype)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
        )
    predicted_image = sdxl.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_image)


@no_grad()
def test_style_aligned(
    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
    expected_style_aligned: Image.Image,
):
    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
    sdxl.dtype = torch.float16  # FIXME: should not be necessary

    style_aligned_adapter = StyleAlignedAdapter(sdxl.unet)
    style_aligned_adapter.inject()

    set_of_prompts = [
        "a toy train. macro photo. 3d game asset",
        "a toy airplane. macro photo. 3d game asset",
        "a toy bicycle. macro photo. 3d game asset",
        "a toy car. macro photo. 3d game asset",
        "a toy boat. macro photo. 3d game asset",
    ]

    # create (context) embeddings from prompts
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=set_of_prompts, negative_text=[""] * len(set_of_prompts)
    )
    time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)

    # initialize latents
    manual_seed(seed=2)
    x = torch.randn(
        (len(set_of_prompts), 4, 128, 128),
        device=sdxl.device,
        dtype=sdxl.dtype,
    )

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
        )

    # decode latents
    predicted_images = sdxl.lda.latents_to_images(x)

    # tile all images horizontally
    merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
    for i, image in enumerate(predicted_images):
        merged_image.paste(image, (1024 * i, 0))

    # compare against reference image
    ensure_similar_images(merged_image, expected_style_aligned, min_psnr=12, min_ssim=0.39, min_dinov2=0.95)


@no_grad()
def test_multi_upscaler(
    multi_upscaler: MultiUpscaler,
    clarity_example: Image.Image,
    expected_multi_upscaler: Image.Image,
) -> None:
    generator = torch.Generator(device=multi_upscaler.device)
    generator.manual_seed(37)
    predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
    ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=25, min_ssim=0.85, min_dinov2=0.96)


@no_grad()
def test_multi_upscaler_small(
    multi_upscaler: MultiUpscaler,
    clarity_example: Image.Image,
) -> None:
    image = clarity_example.resize((16, 16))
    image = multi_upscaler.upscale(image)  # check we can upscale a small image
    image = multi_upscaler.upscale(image)  # check we can upscale it twice


@pytest.fixture(scope="module")
def expected_ic_light(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "expected_ic_light.png").convert("RGB")


@pytest.fixture(scope="module")
def ic_light_sd15_fc(
    ic_light_sd15_fc_weights_path: Path,
    sd15_unet_weights_path: Path,
    sd15_autoencoder_weights_path: Path,
    sd15_text_encoder_weights_path: Path,
    test_device: torch.device,
) -> ICLight:
    return ICLight(
        patch_weights=load_from_safetensors(ic_light_sd15_fc_weights_path),
        unet=SD1UNet(in_channels=4).load_from_safetensors(sd15_unet_weights_path),
        lda=SD1Autoencoder().load_from_safetensors(sd15_autoencoder_weights_path),
        clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(sd15_text_encoder_weights_path),
        device=test_device,
    )


@no_grad()
def test_ic_light(
    kitchen_dog: Image.Image,
    kitchen_dog_mask: Image.Image,
    ic_light_sd15_fc: ICLight,
    expected_ic_light: Image.Image,
    test_device: torch.device,
) -> None:
    sd = ic_light_sd15_fc
    manual_seed(2)

    clip_text_embedding = sd.compute_clip_text_embedding(
        text="a photo of dog, purple neon lighting",
        negative_text="lowres, bad anatomy, bad hands, cropped, worst quality",
    )
    ic_light_condition = sd.compute_gray_composite(image=kitchen_dog, mask=kitchen_dog_mask.convert("L"))
    sd.set_ic_light_condition(ic_light_condition)

    x = torch.randn(1, 4, 64, 64, device=test_device)

    for step in sd.steps:
        x = sd(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=2.0,
        )
    predicted_image = sd.lda.latents_to_image(x)

    ensure_similar_images(predicted_image, expected_ic_light, min_psnr=35, min_ssim=0.99)