update tests to use new fixtures

Laurent 2024-10-09 09:28:34 +00:00 committed by Laureηt
parent 94eeb1afc3
commit 316fe6e4f0
26 changed files with 836 additions and 1070 deletions
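Note: the diff below replaces ad-hoc `test_weights_path / ...` path building and per-fixture skip logic with shared `*_weights_path` fixtures. Those fixtures are not part of this diff; as a rough, hypothetical sketch (names and the checkpoint location are assumptions, and the real fixtures may instead resolve Hub entries and download the files), each one presumably boils down to something like:

    # Hypothetical sketch of a shared path fixture, not part of this commit.
    from pathlib import Path

    import pytest


    @pytest.fixture(scope="session")
    def lora_pokemon_weights_path(test_weights_path: Path) -> Path:
        # Assumed location; the real fixture may download or otherwise resolve the checkpoint.
        return test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"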

View file

@@ -1,5 +1,4 @@
 from pathlib import Path
-from warnings import warn
 import pytest
 import torch
@@ -16,14 +15,8 @@ def manager() -> SDLoraManager:
 @pytest.fixture
-def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
-    weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
-    if not weights_path.is_file():
-        warn(f"could not find weights at {weights_path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return load_tensors(weights_path)
+def weights(lora_pokemon_weights_path: Path) -> dict[str, torch.Tensor]:
+    return load_tensors(lora_pokemon_weights_path)
 def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:

View file

@@ -9,6 +9,8 @@ import torch
 from PIL import Image
 from tests.utils import T5TextEmbedder, ensure_similar_images
+from refiners.conversion import controllora_sdxl
+from refiners.conversion.utils import Hub
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
@@ -45,6 +47,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler i
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
+from ..weight_paths import get_path
 def _img_open(path: Path) -> Image.Image:
     return Image.open(path)  # type: ignore
@@ -194,69 +198,85 @@ def expected_style_aligned(ref_path: Path) -> Image.Image:
 @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
 def controlnet_data(
-    ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
+    ref_path: Path,
+    controlnet_depth_weights_path: Path,
+    controlnet_canny_weights_path: Path,
+    controlnet_lineart_weights_path: Path,
+    controlnet_normals_weights_path: Path,
+    controlnet_sam_weights_path: Path,
+    request: pytest.FixtureRequest,
 ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
     cn_name: str = request.param
     condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
     expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
-    weights_fn = {
-        "depth": "lllyasviel_control_v11f1p_sd15_depth",
-        "canny": "lllyasviel_control_v11p_sd15_canny",
-        "lineart": "lllyasviel_control_v11p_sd15_lineart",
-        "normals": "lllyasviel_control_v11p_sd15_normalbae",
-        "sam": "mfidabel_controlnet-segment-anything",
-    }
-    weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
-    yield (cn_name, condition_image, expected_image, weights_path)
+    weights_fn = {
+        "depth": controlnet_depth_weights_path,
+        "canny": controlnet_canny_weights_path,
+        "lineart": controlnet_lineart_weights_path,
+        "normals": controlnet_normals_weights_path,
+        "sam": controlnet_sam_weights_path,
+    }
+    weights_path = weights_fn[cn_name]
+    yield cn_name, condition_image, expected_image, weights_path
 @pytest.fixture(scope="module", params=["canny"])
 def controlnet_data_scale_decay(
-    ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
+    ref_path: Path,
+    controlnet_canny_weights_path: Path,
+    request: pytest.FixtureRequest,
 ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
     cn_name: str = request.param
     condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
     expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
-    weights_fn = {
-        "canny": "lllyasviel_control_v11p_sd15_canny",
-    }
-    weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
+    weights_fn = {
+        "canny": controlnet_canny_weights_path,
+    }
+    weights_path = weights_fn[cn_name]
     yield (cn_name, condition_image, expected_image, weights_path)
 @pytest.fixture(scope="module")
-def controlnet_data_tile(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Image.Image, Path]:
+def controlnet_data_tile(
+    ref_path: Path,
+    controlnet_tiles_weights_path: Path,
+) -> tuple[Image.Image, Image.Image, Path]:
     condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024))  # type: ignore
     expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
-    weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
-    return condition_image, expected_image, weights_path
+    return condition_image, expected_image, controlnet_tiles_weights_path
 @pytest.fixture(scope="module")
-def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
+def controlnet_data_canny(
+    ref_path: Path,
+    controlnet_canny_weights_path: Path,
+) -> tuple[str, Image.Image, Image.Image, Path]:
     cn_name = "canny"
     condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
     expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
-    weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
-    return cn_name, condition_image, expected_image, weights_path
+    return cn_name, condition_image, expected_image, controlnet_canny_weights_path
 @pytest.fixture(scope="module")
-def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
+def controlnet_data_depth(
+    ref_path: Path,
+    controlnet_depth_weights_path: Path,
+) -> tuple[str, Image.Image, Image.Image, Path]:
     cn_name = "depth"
     condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
     expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
-    weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
-    return cn_name, condition_image, expected_image, weights_path
+    return cn_name, condition_image, expected_image, controlnet_depth_weights_path
 @dataclass
 class ControlLoraConfig:
     scale: float
     condition_path: str
-    weights_path: str
+    weights: Hub
 @dataclass
@@ -271,38 +291,38 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
         "PyraCanny": ControlLoraConfig(
             scale=1.0,
            condition_path="cutecat_guide_PyraCanny.png",
-            weights_path="refiners_control-lora-canny-rank128.safetensors",
+            weights=controllora_sdxl.canny.converted,
         ),
     },
     "expected_controllora_CPDS.png": {
         "CPDS": ControlLoraConfig(
             scale=1.0,
             condition_path="cutecat_guide_CPDS.png",
-            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
+            weights=controllora_sdxl.cpds.converted,
         ),
     },
     "expected_controllora_PyraCanny+CPDS.png": {
         "PyraCanny": ControlLoraConfig(
             scale=0.55,
             condition_path="cutecat_guide_PyraCanny.png",
-            weights_path="refiners_control-lora-canny-rank128.safetensors",
+            weights=controllora_sdxl.canny.converted,
         ),
         "CPDS": ControlLoraConfig(
             scale=0.55,
             condition_path="cutecat_guide_CPDS.png",
-            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
+            weights=controllora_sdxl.cpds.converted,
         ),
     },
     "expected_controllora_disabled.png": {
         "PyraCanny": ControlLoraConfig(
             scale=0.0,
             condition_path="cutecat_guide_PyraCanny.png",
-            weights_path="refiners_control-lora-canny-rank128.safetensors",
+            weights=controllora_sdxl.canny.converted,
         ),
         "CPDS": ControlLoraConfig(
             scale=0.0,
             condition_path="cutecat_guide_CPDS.png",
-            weights_path="refiners_fooocus_xl_cpds_128.safetensors",
+            weights=controllora_sdxl.cpds.converted,
         ),
     },
 }
@@ -311,8 +331,8 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
 @pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
 def controllora_sdxl_config(
     request: pytest.FixtureRequest,
+    use_local_weights: bool,
     ref_path: Path,
-    test_weights_path: Path,
 ) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
     name: str = request.param[0]
     configs: dict[str, ControlLoraConfig] = request.param[1]
@@ -322,7 +342,7 @@ def controllora_sdxl_config(
         config_name: ControlLoraResolvedConfig(
             scale=config.scale,
             condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
-            weights_path=test_weights_path / "control-loras" / config.weights_path,
+            weights_path=get_path(config.weights, use_local_weights),
         )
         for config_name, config in configs.items()
     }
@@ -331,66 +351,57 @@ def controllora_sdxl_config(
 @pytest.fixture(scope="module")
-def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
+def t2i_adapter_data_depth(
+    ref_path: Path,
+    t2i_depth_weights_path: Path,
+) -> tuple[str, Image.Image, Image.Image, Path]:
     name = "depth"
     condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
     expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
-    weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
-    return name, condition_image, expected_image, weights_path
+    return name, condition_image, expected_image, t2i_depth_weights_path
 @pytest.fixture(scope="module")
-def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
+def t2i_adapter_xl_data_canny(
+    ref_path: Path,
+    t2i_sdxl_canny_weights_path: Path,
+) -> tuple[str, Image.Image, Image.Image, Path]:
     name = "canny"
     condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
     expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
-    weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
-    if not weights_path.is_file():
-        warn(f"could not find weights at {weights_path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return name, condition_image, expected_image, weights_path
+    return name, condition_image, expected_image, t2i_sdxl_canny_weights_path
 @pytest.fixture(scope="module")
-def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
+def lora_data_pokemon(
+    ref_path: Path,
+    lora_pokemon_weights_path: Path,
+) -> tuple[Image.Image, dict[str, torch.Tensor]]:
     expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
-    weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
-    if not weights_path.is_file():
-        warn(f"could not find weights at {weights_path}, skipping")
-        pytest.skip(allow_module_level=True)
-    tensors = load_tensors(weights_path)
+    tensors = load_tensors(lora_pokemon_weights_path)
     return expected_image, tensors
 @pytest.fixture(scope="module")
-def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
+def lora_data_dpo(
+    ref_path: Path,
+    lora_dpo_weights_path: Path,
+) -> tuple[Image.Image, dict[str, torch.Tensor]]:
     expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
-    weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
-    if not weights_path.is_file():
-        warn(f"could not find weights at {weights_path}, skipping")
-        pytest.skip(allow_module_level=True)
-    tensors = load_from_safetensors(weights_path)
+    tensors = load_from_safetensors(lora_dpo_weights_path)
     return expected_image, tensors
 @pytest.fixture(scope="module")
-def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
-    weights_path = test_weights_path / "loras" / "sliders"
-    if not weights_path.is_dir():
-        warn(f"could not find weights at {weights_path}, skipping")
-        pytest.skip(allow_module_level=True)
+def lora_sliders(
+    lora_slider_age_weights_path: Path,
+    lora_slider_cartoon_style_weights_path: Path,
+    lora_slider_eyesize_weights_path: Path,
+) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
     return {
-        "age": load_tensors(weights_path / "age.pt"),  # type: ignore
-        "cartoon_style": load_tensors(weights_path / "cartoon_style.pt"),  # type: ignore
-        "eyesize": load_tensors(weights_path / "eyesize.pt"),  # type: ignore
+        "age": load_tensors(lora_slider_age_weights_path),
+        "cartoon_style": load_tensors(lora_slider_cartoon_style_weights_path),
+        "eyesize": load_tensors(lora_slider_eyesize_weights_path),
     }, {
         "age": 0.3,
         "cartoon_style": -0.2,
@@ -477,122 +488,12 @@ def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch
     return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
@pytest.fixture(scope="module")
def text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not text_encoder_weights.is_file():
warn(f"could not find weights at {text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return text_encoder_weights
@pytest.fixture(scope="module")
def lda_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def unet_weights_std(test_weights_path: Path) -> Path:
unet_weights_std = test_weights_path / "unet.safetensors"
if not unet_weights_std.is_file():
warn(f"could not find weights at {unet_weights_std}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_std
@pytest.fixture(scope="module")
def unet_weights_inpainting(test_weights_path: Path) -> Path:
unet_weights_inpainting = test_weights_path / "inpainting" / "unet.safetensors"
if not unet_weights_inpainting.is_file():
warn(f"could not find weights at {unet_weights_inpainting}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_inpainting
@pytest.fixture(scope="module")
def lda_ft_mse_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda_ft_mse.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def ella_weights(test_weights_path: Path) -> tuple[Path, Path]:
ella_adapter_weights = test_weights_path / "ELLA-Adapter" / "ella-sd1.5-tsc-t5xl.safetensors"
if not ella_adapter_weights.is_file():
warn(f"could not find weights at {ella_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
t5xl_weights = test_weights_path / "QQGYLab" / "T5XLFP16"
t5xl_files = [
"config.json",
"model.safetensors",
"special_tokens_map.json",
"spiece.model",
"tokenizer_config.json",
"tokenizer.json",
]
for file in t5xl_files:
if not (t5xl_weights / file).is_file():
warn(f"could not find weights at {t5xl_weights / file}, skipping")
pytest.skip(allow_module_level=True)
return (ella_adapter_weights, t5xl_weights)
@pytest.fixture(scope="module")
def ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def ip_adapter_plus_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not image_encoder_weights.is_file():
warn(f"could not find weights at {image_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return image_encoder_weights
 @pytest.fixture
 def sd15_std(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -600,16 +501,19 @@ def sd15_std(
     sd15 = StableDiffusion_1(device=test_device)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_std)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
 @pytest.fixture
 def sd15_std_sde(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -618,16 +522,19 @@ def sd15_std_sde(
     sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
     sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_std)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
 @pytest.fixture
 def sd15_std_float16(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -635,18 +542,18 @@ def sd15_std_float16(
     sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_std)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
 @pytest.fixture
 def sd15_std_bfloat16(
-    text_encoder_weights: Path,
-    lda_weights: Path,
-    unet_weights_std: Path,
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_weights_path: Path,
     test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
@@ -655,16 +562,19 @@ def sd15_std_bfloat16(
     sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_std)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
 @pytest.fixture
 def sd15_inpainting(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_inpainting_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1_Inpainting:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -673,16 +583,19 @@ def sd15_inpainting(
     unet = SD1UNet(in_channels=9)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_inpainting)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
     return sd15
 @pytest.fixture
 def sd15_inpainting_float16(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_inpainting_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1_Inpainting:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -691,16 +604,19 @@ def sd15_inpainting_float16(
     unet = SD1UNet(in_channels=9)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_inpainting)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
     return sd15
 @pytest.fixture
 def sd15_ddim(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -709,16 +625,19 @@ def sd15_ddim(
     ddim_solver = DDIM(num_inference_steps=20)
     sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_std)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
 @pytest.fixture
 def sd15_ddim_karras(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -727,16 +646,18 @@ def sd15_ddim_karras(
     ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
     sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_std)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
 @pytest.fixture
 def sd15_euler(
-    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -745,16 +666,19 @@ def sd15_euler(
     euler_solver = Euler(num_inference_steps=30)
     sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
-    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
-    sd15.lda.load_from_safetensors(lda_weights)
-    sd15.unet.load_from_safetensors(unet_weights_std)
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
 @pytest.fixture
 def sd15_ddim_lda_ft_mse(
-    text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_mse_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_1:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
@@ -763,52 +687,19 @@ def sd15_ddim_lda_ft_mse(
     ddim_solver = DDIM(num_inference_steps=20)
     sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
-    sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
-    sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
-    sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
+    sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
+    sd15.lda.load_from_safetensors(sd15_autoencoder_mse_weights_path)
+    sd15.unet.load_from_safetensors(sd15_unet_weights_path)
     return sd15
-@pytest.fixture
-def sdxl_lda_weights(test_weights_path: Path) -> Path:
-    sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
-    if not sdxl_lda_weights.is_file():
-        warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    return sdxl_lda_weights
-@pytest.fixture
-def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
-    sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
-    if not sdxl_lda_weights.is_file():
-        warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    return sdxl_lda_weights
-@pytest.fixture
-def sdxl_unet_weights(test_weights_path: Path) -> Path:
-    sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
-    if not sdxl_unet_weights.is_file():
-        warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    return sdxl_unet_weights
-@pytest.fixture
-def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
-    sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
-    if not sdxl_double_text_encoder_weights.is_file():
-        warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    return sdxl_double_text_encoder_weights
 @pytest.fixture
 def sdxl_ddim(
-    sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
+    sdxl_text_encoder_weights_path: Path,
+    sdxl_autoencoder_weights_path: Path,
+    sdxl_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_XL:
     if test_device.type == "cpu":
         warn(message="not running on CPU, skipping")
@@ -817,16 +708,19 @@ def sdxl_ddim(
     solver = DDIM(num_inference_steps=30)
     sdxl = StableDiffusion_XL(solver=solver, device=test_device)
-    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
-    sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
-    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
+    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
+    sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_weights_path)
+    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
     return sdxl
 @pytest.fixture
 def sdxl_ddim_lda_fp16_fix(
-    sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
+    sdxl_text_encoder_weights_path: Path,
+    sdxl_autoencoder_fp16fix_weights_path: Path,
+    sdxl_unet_weights_path: Path,
+    test_device: torch.device,
 ) -> StableDiffusion_XL:
     if test_device.type == "cpu":
         warn(message="not running on CPU, skipping")
@@ -835,9 +729,9 @@ def sdxl_ddim_lda_fp16_fix(
     solver = DDIM(num_inference_steps=30)
     sdxl = StableDiffusion_XL(solver=solver, device=test_device)
-    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
-    sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
-    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
+    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
+    sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
+    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
     return sdxl
@@ -856,23 +750,18 @@ def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_X
 @pytest.fixture(scope="module")
 def multi_upscaler(
-    test_weights_path: Path,
-    unet_weights_std: Path,
-    text_encoder_weights: Path,
-    lda_ft_mse_weights: Path,
+    controlnet_tiles_weights_path: Path,
+    sd15_text_encoder_weights_path: Path,
+    sd15_autoencoder_mse_weights_path: Path,
+    sd15_unet_weights_path: Path,
     test_device: torch.device,
 ) -> MultiUpscaler:
-    controlnet_tile_weights = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
-    if not controlnet_tile_weights.is_file():
-        warn(message=f"could not find weights at {controlnet_tile_weights}, skipping")
-        pytest.skip(allow_module_level=True)
     return MultiUpscaler(
         checkpoints=UpscalerCheckpoints(
-            unet=unet_weights_std,
-            clip_text_encoder=text_encoder_weights,
-            lda=lda_ft_mse_weights,
-            controlnet_tile=controlnet_tile_weights,
+            unet=sd15_unet_weights_path,
+            clip_text_encoder=sd15_text_encoder_weights_path,
+            lda=sd15_autoencoder_mse_weights_path,
+            controlnet_tile=controlnet_tiles_weights_path,
         ),
         device=test_device,
         dtype=torch.float32,
@@ -891,7 +780,9 @@ def expected_multi_upscaler(ref_path: Path) -> Image.Image:
 @no_grad()
 def test_diffusion_std_random_init(
-    sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
+    sd15_std: StableDiffusion_1,
+    expected_image_std_random_init: Image.Image,
+    test_device: torch.device,
 ):
     sd15 = sd15_std
@@ -1553,6 +1444,9 @@ def test_diffusion_sdxl_control_lora(
     adapters: dict[str, ControlLoraAdapter] = {}
     for config_name, config in configs.items():
+        if not config.weights_path.is_file():
+            pytest.skip(f"File not found: {config.weights_path}")
         adapter = ControlLoraAdapter(
             name=config_name,
             scale=config.scale,
@@ -1922,13 +1816,18 @@ def test_diffusion_textual_inversion_random_init(
 @no_grad()
 def test_diffusion_ella_adapter(
     sd15_std_float16: StableDiffusion_1,
-    ella_weights: tuple[Path, Path],
+    ella_sd15_tsc_t5xl_weights_path: Path,
+    t5xl_transformers_path: str,
     expected_image_ella_adapter: Image.Image,
     test_device: torch.device,
+    use_local_weights: bool,
 ):
     sd15 = sd15_std_float16
-    ella_adapter_weights, t5xl_weights = ella_weights
-    t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16)
+    t5_encoder = T5TextEmbedder(
+        pretrained_path=t5xl_transformers_path,
+        local_files_only=use_local_weights,
+        max_length=128,
+    ).to(test_device, torch.float16)
     prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
     negative_prompt = ""
@@ -1938,7 +1837,7 @@ def test_diffusion_ella_adapter(
     llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt)
     prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)
-    adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_adapter_weights))
+    adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_sd15_tsc_t5xl_weights_path))
     adapter.inject()
     sd15.set_inference_steps(50)
     manual_seed(1001)
@@ -1959,8 +1858,8 @@ def test_diffusion_ella_adapter(
 @no_grad()
 def test_diffusion_ip_adapter(
     sd15_ddim_lda_ft_mse: StableDiffusion_1,
-    ip_adapter_weights: Path,
-    image_encoder_weights: Path,
+    ip_adapter_sd15_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
     woman_image: Image.Image,
     expected_image_ip_adapter_woman: Image.Image,
     test_device: torch.device,
@@ -1976,8 +1875,8 @@ def test_diffusion_ip_adapter(
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
@@ -2004,8 +1903,8 @@ def test_diffusion_ip_adapter(
 @no_grad()
 def test_diffusion_ip_adapter_multi(
     sd15_ddim_lda_ft_mse: StableDiffusion_1,
-    ip_adapter_weights: Path,
-    image_encoder_weights: Path,
+    ip_adapter_sd15_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
     woman_image: Image.Image,
     statue_image: Image.Image,
     expected_image_ip_adapter_multi: Image.Image,
@@ -2016,8 +1915,8 @@ def test_diffusion_ip_adapter_multi(
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
@@ -2044,8 +1943,8 @@ def test_diffusion_ip_adapter_multi(
 @no_grad()
 def test_diffusion_sdxl_ip_adapter(
     sdxl_ddim: StableDiffusion_XL,
-    sdxl_ip_adapter_weights: Path,
-    image_encoder_weights: Path,
+    ip_adapter_sdxl_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
     woman_image: Image.Image,
     expected_image_sdxl_ip_adapter_woman: Image.Image,
     test_device: torch.device,
@@ -2055,8 +1954,8 @@ def test_diffusion_sdxl_ip_adapter(
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     with no_grad():
@@ -2093,8 +1992,8 @@ def test_diffusion_sdxl_ip_adapter(
 @no_grad()
 def test_diffusion_ip_adapter_controlnet(
     sd15_ddim: StableDiffusion_1,
-    ip_adapter_weights: Path,
-    image_encoder_weights: Path,
+    ip_adapter_sd15_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
     lora_data_pokemon: tuple[Image.Image, Path],
     controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
     expected_image_ip_adapter_controlnet: Image.Image,
@@ -2107,8 +2006,8 @@ def test_diffusion_ip_adapter_controlnet(
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     depth_controlnet = SD1ControlnetAdapter(
@@ -2149,8 +2048,8 @@ def test_diffusion_ip_adapter_controlnet(
 @no_grad()
 def test_diffusion_ip_adapter_plus(
     sd15_ddim_lda_ft_mse: StableDiffusion_1,
-    ip_adapter_plus_weights: Path,
-    image_encoder_weights: Path,
+    ip_adapter_sd15_plus_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
     statue_image: Image.Image,
     expected_image_ip_adapter_plus_statue: Image.Image,
     test_device: torch.device,
@@ -2161,9 +2060,9 @@ def test_diffusion_ip_adapter_plus(
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
     ip_adapter = SD1IPAdapter(
-        target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True
+        target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_plus_weights_path), fine_grained=True
     )
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
@@ -2190,8 +2089,8 @@ def test_diffusion_ip_adapter_plus(
 @no_grad()
 def test_diffusion_sdxl_ip_adapter_plus(
     sdxl_ddim: StableDiffusion_XL,
-    sdxl_ip_adapter_plus_weights: Path,
-    image_encoder_weights: Path,
+    ip_adapter_sdxl_plus_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
     woman_image: Image.Image,
     expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
     test_device: torch.device,
@@ -2202,9 +2101,9 @@ def test_diffusion_sdxl_ip_adapter_plus(
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
     ip_adapter = SDXLIPAdapter(
-        target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True
+        target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path), fine_grained=True
     )
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
@@ -2608,8 +2507,8 @@ def test_freeu(
 def test_hello_world(
     sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
     t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
-    sdxl_ip_adapter_weights: Path,
-    image_encoder_weights: Path,
+    ip_adapter_sdxl_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
     hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
 ) -> None:
     sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
@@ -2622,8 +2521,8 @@ def test_hello_world(
         warn(f"could not find weights at {weights_path}, skipping")
         pytest.skip(allow_module_level=True)
-    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
@@ -2743,24 +2642,19 @@ def expected_ic_light(ref_path: Path) -> Image.Image:
     return _img_open(ref_path / "expected_ic_light.png").convert("RGB")
-@pytest.fixture(scope="module")
-def ic_light_sd15_fc_weights(test_weights_path: Path) -> Path:
-    return test_weights_path / "iclight_sd15_fc-refiners.safetensors"
 @pytest.fixture(scope="module")
 def ic_light_sd15_fc(
-    ic_light_sd15_fc_weights: Path,
-    unet_weights_std: Path,
-    lda_weights: Path,
-    text_encoder_weights: Path,
+    ic_light_sd15_fc_weights_path: Path,
+    sd15_unet_weights_path: Path,
+    sd15_autoencoder_weights_path: Path,
+    sd15_text_encoder_weights_path: Path,
     test_device: torch.device,
 ) -> ICLight:
     return ICLight(
-        patch_weights=load_from_safetensors(ic_light_sd15_fc_weights),
-        unet=SD1UNet(in_channels=4).load_from_safetensors(unet_weights_std),
-        lda=SD1Autoencoder().load_from_safetensors(lda_weights),
-        clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(text_encoder_weights),
+        patch_weights=load_from_safetensors(ic_light_sd15_fc_weights_path),
+        unet=SD1UNet(in_channels=4).load_from_safetensors(sd15_unet_weights_path),
+        lda=SD1Autoencoder().load_from_safetensors(sd15_autoencoder_weights_path),
+        clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(sd15_text_encoder_weights_path),
         device=test_device,
     )

View file

@@ -29,74 +29,11 @@ def ref_path(test_e2e_path: Path) -> Path:
     return test_e2e_path / "test_doc_examples_ref"
-@pytest.fixture(scope="module")
-def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-@pytest.fixture(scope="module")
-def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-@pytest.fixture(scope="module")
-def sdxl_unet_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "sdxl-unet.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-@pytest.fixture(scope="module")
-def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
-    if not path.is_file():
-        warn(f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-@pytest.fixture(scope="module")
-def image_encoder_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "CLIPImageEncoderH.safetensors"
-    if not path.is_file():
-        warn(f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-@pytest.fixture
-def scifi_lora_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-@pytest.fixture
-def pixelart_lora_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
 @pytest.fixture
 def sdxl(
-    sdxl_text_encoder_weights: Path,
-    sdxl_lda_fp16_fix_weights: Path,
-    sdxl_unet_weights: Path,
+    sdxl_text_encoder_weights_path: Path,
+    sdxl_autoencoder_fp16fix_weights_path: Path,
+    sdxl_unet_weights_path: Path,
     test_device: torch.device,
 ) -> StableDiffusion_XL:
     if test_device.type == "cpu":
@@ -105,9 +42,9 @@ def sdxl(
     sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16)
-    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
-    sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
-    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
+    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
+    sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
+    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
     return sdxl
@@ -180,7 +117,7 @@ def test_guide_adapting_sdxl_vanilla(
 def test_guide_adapting_sdxl_single_lora(
     test_device: torch.device,
     sdxl: StableDiffusion_XL,
-    scifi_lora_weights: Path,
+    lora_scifi_weights_path: Path,
     expected_image_guide_adapting_sdxl_single_lora: Image.Image,
 ) -> None:
     if test_device.type == "cpu":
@@ -195,7 +132,7 @@ def test_guide_adapting_sdxl_single_lora(
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
     manager = SDLoraManager(sdxl)
-    manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
+    manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text=prompt + ", best quality, high quality",
@@ -222,8 +159,8 @@ def test_guide_adapting_sdxl_single_lora(
 def test_guide_adapting_sdxl_multiple_loras(
     test_device: torch.device,
     sdxl: StableDiffusion_XL,
-    scifi_lora_weights: Path,
-    pixelart_lora_weights: Path,
+    lora_scifi_weights_path: Path,
+    lora_pixelart_weights_path: Path,
     expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
 ) -> None:
     if test_device.type == "cpu":
@@ -238,8 +175,8 @@ def test_guide_adapting_sdxl_multiple_loras(
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
     manager = SDLoraManager(sdxl)
-    manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
-    manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4)
+    manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
+    manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.4)
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text=prompt + ", best quality, high quality",
@@ -266,10 +203,10 @@ def test_guide_adapting_sdxl_multiple_loras(
 def test_guide_adapting_sdxl_loras_ip_adapter(
     test_device: torch.device,
     sdxl: StableDiffusion_XL,
-    sdxl_ip_adapter_plus_weights: Path,
-    image_encoder_weights: Path,
-    scifi_lora_weights: Path,
-    pixelart_lora_weights: Path,
+    ip_adapter_sdxl_plus_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
+    lora_scifi_weights_path: Path,
+    lora_pixelart_weights_path: Path,
     image_prompt_german_castle: Image.Image,
     expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
 ) -> None:
@@ -285,16 +222,16 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
     manager = SDLoraManager(sdxl)
-    manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5)
-    manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55)
+    manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path), scale=1.5)
+    manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.55)
     ip_adapter = SDXLIPAdapter(
         target=sdxl.unet,
-        weights=load_from_safetensors(sdxl_ip_adapter_plus_weights),
+        weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path),
         scale=1.0,
         fine_grained=True,
     )
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
View file
@ -26,51 +26,6 @@ def ensure_gc():
gc.collect() gc.collect()
@pytest.fixture
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lcm_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lcm-unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lcm_lora_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lcm-lora.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
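These module-local fixtures are superseded by shared `*_weights_path` fixtures. A minimal sketch of what such a shared fixture could look like, assuming weights are pulled from the Hugging Face Hub and the test is skipped when they cannot be fetched; the repo id below is an assumption for illustration, only the filename comes from the old local fixture:
from pathlib import Path
import pytest
from huggingface_hub import hf_hub_download  # type: ignore
@pytest.fixture(scope="session")
def lora_sdxl_lcm_weights_path() -> Path:
    try:
        # "refiners/sdxl.lora.lcm" is a hypothetical repo id; the filename matches the old local file
        downloaded = hf_hub_download(repo_id="refiners/sdxl.lora.lcm", filename="sdxl-lcm-lora.safetensors")
    except Exception:
        pytest.skip("could not fetch LCM LoRA weights, skipping")
    return Path(downloaded)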
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path: def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_lcm_ref" return test_e2e_path / "test_lcm_ref"
@ -94,9 +49,9 @@ def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
@no_grad() @no_grad()
def test_lcm_base( def test_lcm_base(
test_device: torch.device, test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path, sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_lcm_unet_weights: Path, sdxl_unet_lcm_weights_path: Path,
sdxl_text_encoder_weights: Path, sdxl_text_encoder_weights_path: Path,
expected_lcm_base: Image.Image, expected_lcm_base: Image.Image,
) -> None: ) -> None:
if test_device.type == "cpu": if test_device.type == "cpu":
@ -111,9 +66,9 @@ def test_lcm_base(
# not in the diffusion loop. # not in the diffusion loop.
SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject() SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_lcm_unet_weights) sdxl.unet.load_from_safetensors(sdxl_unet_lcm_weights_path)
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_base expected_image = expected_lcm_base
@ -141,10 +96,10 @@ def test_lcm_base(
@pytest.mark.parametrize("condition_scale", [1.0, 1.2]) @pytest.mark.parametrize("condition_scale", [1.0, 1.2])
def test_lcm_lora_with_guidance( def test_lcm_lora_with_guidance(
test_device: torch.device, test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path, sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights: Path, sdxl_unet_weights_path: Path,
sdxl_text_encoder_weights: Path, sdxl_text_encoder_weights_path: Path,
sdxl_lcm_lora_weights: Path, lora_sdxl_lcm_weights_path: Path,
expected_lcm_lora_1_0: Image.Image, expected_lcm_lora_1_0: Image.Image,
expected_lcm_lora_1_2: Image.Image, expected_lcm_lora_1_2: Image.Image,
condition_scale: float, condition_scale: float,
@ -156,12 +111,12 @@ def test_lcm_lora_with_guidance(
solver = LCMSolver(num_inference_steps=4) solver = LCMSolver(num_inference_steps=4)
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_unet_weights) sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights)) add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2 expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
@ -191,10 +146,10 @@ def test_lcm_lora_with_guidance(
@no_grad() @no_grad()
def test_lcm_lora_without_guidance( def test_lcm_lora_without_guidance(
test_device: torch.device, test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path, sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights: Path, sdxl_unet_weights_path: Path,
sdxl_text_encoder_weights: Path, sdxl_text_encoder_weights_path: Path,
sdxl_lcm_lora_weights: Path, lora_sdxl_lcm_weights_path: Path,
expected_lcm_lora_1_0: Image.Image, expected_lcm_lora_1_0: Image.Image,
) -> None: ) -> None:
if test_device.type == "cpu": if test_device.type == "cpu":
@ -205,12 +160,12 @@ def test_lcm_lora_without_guidance(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_unet_weights) sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights)) add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0 expected_image = expected_lcm_lora_1_0
View file
@ -25,60 +25,6 @@ def ensure_gc():
gc.collect() gc.collect()
@pytest.fixture
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lightning_4step_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl_lightning_4step_unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lightning_1step_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl_lightning_1step_unet_x0.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lightning_4step_lora_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl_lightning_4step_lora.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path: def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_lightning_ref" return test_e2e_path / "test_lightning_ref"
@ -102,16 +48,16 @@ def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
@no_grad() @no_grad()
def test_lightning_base_4step( def test_lightning_base_4step(
test_device: torch.device, test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path, sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_lightning_4step_unet_weights: Path, sdxl_unet_lightning_4step_weights_path: Path,
sdxl_text_encoder_weights: Path, sdxl_text_encoder_weights_path: Path,
expected_lightning_base_4step: Image.Image, expected_lightning_base_4step: Image.Image,
) -> None: ) -> None:
if test_device.type == "cpu": if test_device.type == "cpu":
warn(message="not running on CPU, skipping") warn(message="not running on CPU, skipping")
pytest.skip() pytest.skip()
unet_weights = sdxl_lightning_4step_unet_weights unet_weights = sdxl_unet_lightning_4step_weights_path
expected_image = expected_lightning_base_4step expected_image = expected_lightning_base_4step
solver = Euler( solver = Euler(
@ -125,8 +71,8 @@ def test_lightning_base_4step(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(unet_weights) sdxl.unet.load_from_safetensors(unet_weights)
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
@ -153,16 +99,16 @@ def test_lightning_base_4step(
@no_grad() @no_grad()
def test_lightning_base_1step( def test_lightning_base_1step(
test_device: torch.device, test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path, sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_lightning_1step_unet_weights: Path, sdxl_unet_lightning_1step_weights_path: Path,
sdxl_text_encoder_weights: Path, sdxl_text_encoder_weights_path: Path,
expected_lightning_base_1step: Image.Image, expected_lightning_base_1step: Image.Image,
) -> None: ) -> None:
if test_device.type == "cpu": if test_device.type == "cpu":
warn(message="not running on CPU, skipping") warn(message="not running on CPU, skipping")
pytest.skip() pytest.skip()
unet_weights = sdxl_lightning_1step_unet_weights unet_weights = sdxl_unet_lightning_1step_weights_path
expected_image = expected_lightning_base_1step expected_image = expected_lightning_base_1step
solver = Euler( solver = Euler(
@ -176,8 +122,8 @@ def test_lightning_base_1step(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(unet_weights) sdxl.unet.load_from_safetensors(unet_weights)
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
@ -204,10 +150,10 @@ def test_lightning_base_1step(
@no_grad() @no_grad()
def test_lightning_lora_4step( def test_lightning_lora_4step(
test_device: torch.device, test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path, sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights: Path, sdxl_unet_weights_path: Path,
sdxl_text_encoder_weights: Path, sdxl_text_encoder_weights_path: Path,
sdxl_lightning_4step_lora_weights: Path, lora_sdxl_lightning_4step_weights_path: Path,
expected_lightning_lora_4step: Image.Image, expected_lightning_lora_4step: Image.Image,
) -> None: ) -> None:
if test_device.type == "cpu": if test_device.type == "cpu":
@ -227,12 +173,12 @@ def test_lightning_lora_4step(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_unet_weights) sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
manager = SDLoraManager(sdxl) manager = SDLoraManager(sdxl)
add_lcm_lora(manager, load_from_safetensors(sdxl_lightning_4step_lora_weights), name="lightning") add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lightning_4step_weights_path), name="lightning")
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
View file
@ -29,19 +29,10 @@ def expected_cactus_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cactus_mask.png") return _img_open(ref_path / "expected_cactus_mask.png")
@pytest.fixture(scope="module")
def mvanet_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "mvanet" / "mvanet.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
@pytest.fixture @pytest.fixture
def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet: def mvanet_model(mvanet_weights_path: Path, test_device: torch.device) -> MVANet:
model = MVANet(device=test_device).eval() # .eval() is important! model = MVANet(device=test_device).eval() # .eval() is important!
model.load_from_safetensors(mvanet_weights) model.load_from_safetensors(mvanet_weights_path)
return model return model
@ -61,7 +52,7 @@ def test_mvanet(
@no_grad() @no_grad()
def test_mvanet_to( def test_mvanet_to(
mvanet_weights: Path, mvanet_weights_path: Path,
ref_cactus: Image.Image, ref_cactus: Image.Image,
expected_cactus_mask: Image.Image, expected_cactus_mask: Image.Image,
test_device: torch.device, test_device: torch.device,
@ -71,7 +62,7 @@ def test_mvanet_to(
pytest.skip() pytest.skip()
model = MVANet(device=torch.device("cpu")).eval() model = MVANet(device=torch.device("cpu")).eval()
model.load_from_safetensors(mvanet_weights) model.load_from_safetensors(mvanet_weights_path)
model.to(test_device) model.to(test_device)
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze() in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
View file
@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from warnings import warn
import pytest import pytest
import torch import torch
@ -29,19 +28,13 @@ def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image
return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB") return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
@pytest.fixture(scope="module")
def informative_drawings_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "informative-drawings.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
@pytest.fixture @pytest.fixture
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings: def informative_drawings_model(
controlnet_preprocessor_info_drawings_weights_path: Path,
test_device: torch.device,
) -> InformativeDrawings:
model = InformativeDrawings(device=test_device) model = InformativeDrawings(device=test_device)
model.load_from_safetensors(informative_drawings_weights) model.load_from_safetensors(controlnet_preprocessor_info_drawings_weights_path)
return model return model
View file
@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from warnings import warn
import pytest import pytest
import torch import torch
@ -38,24 +37,15 @@ def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png") return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
@pytest.fixture(scope="module")
def box_segmenter_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
def test_box_segmenter( def test_box_segmenter(
box_segmenter_weights: Path, box_segmenter_weights_path: Path,
ref_shelves: Image.Image, ref_shelves: Image.Image,
expected_box_segmenter_plant_mask: Image.Image, expected_box_segmenter_plant_mask: Image.Image,
expected_box_segmenter_spray_mask: Image.Image, expected_box_segmenter_spray_mask: Image.Image,
expected_box_segmenter_spray_cropped_mask: Image.Image, expected_box_segmenter_spray_cropped_mask: Image.Image,
test_device: torch.device, test_device: torch.device,
): ):
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device) segmenter = BoxSegmenter(weights=box_segmenter_weights_path, device=test_device)
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368)) plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB")) ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
View file
@ -4,7 +4,7 @@ import torch
from torch import Tensor, nn from torch import Tensor, nn
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ConversionStage, ModelConverter from refiners.conversion.model_converter import ConversionStage, ModelConverter
from refiners.fluxion.utils import manual_seed from refiners.fluxion.utils import manual_seed
View file
@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from warnings import warn
import pytest import pytest
import torch import torch
@ -20,17 +19,13 @@ PROMPTS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def our_encoder_with_new_concepts( def our_encoder_with_new_concepts(
test_weights_path: Path, sd15_text_encoder_weights_path: Path,
test_device: torch.device, test_device: torch.device,
cat_embedding_textual_inversion: torch.Tensor, cat_embedding_textual_inversion: torch.Tensor,
gta5_artwork_embedding_textual_inversion: torch.Tensor, gta5_artwork_embedding_textual_inversion: torch.Tensor,
) -> CLIPTextEncoderL: ) -> CLIPTextEncoderL:
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = CLIPTextEncoderL(device=test_device) encoder = CLIPTextEncoderL(device=test_device)
tensors = load_from_safetensors(weights) tensors = load_from_safetensors(sd15_text_encoder_weights_path)
encoder.load_state_dict(tensors) encoder.load_state_dict(tensors)
concept_extender = ConceptExtender(encoder) concept_extender = ConceptExtender(encoder)
concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion) concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion)
@ -41,24 +36,21 @@ def our_encoder_with_new_concepts(
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_sd15_with_new_concepts( def ref_sd15_with_new_concepts(
runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device sd15_diffusers_runwayml_path: str,
test_textual_inversion_path: Path,
test_device: torch.device,
use_local_weights: bool,
) -> StableDiffusionPipeline: ) -> StableDiffusionPipeline:
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path).to(test_device) # type: ignore pipe = StableDiffusionPipeline.from_pretrained( # type: ignore
sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
).to(test_device) # type: ignore
assert isinstance(pipe, StableDiffusionPipeline) assert isinstance(pipe, StableDiffusionPipeline)
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
return pipe return pipe
@pytest.fixture(scope="module")
def runwayml_weights_path(test_weights_path: Path):
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
if not r.is_dir():
warn(f"could not find RunwayML weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer: def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer:
return ref_sd15_with_new_concepts.tokenizer # type: ignore return ref_sd15_with_new_concepts.tokenizer # type: ignore
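The reference pipeline above now receives `sd15_diffusers_runwayml_path` as a plain string plus a `use_local_weights` flag forwarded to `local_files_only`. A hedged sketch of how that flag could be exposed from the shared conftest; the command-line option name and its default are assumptions:
import pytest
def pytest_addoption(parser: pytest.Parser) -> None:
    # hypothetical flag; the real conftest may name or default it differently
    parser.addoption("--use-local-weights", action="store_true", default=False)
@pytest.fixture(scope="session")
def use_local_weights(request: pytest.FixtureRequest) -> bool:
    return bool(request.config.getoption("--use-local-weights"))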
View file
@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from warnings import warn
import pytest import pytest
import torch import torch
@ -11,39 +10,28 @@ from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def our_encoder( def our_encoder(
test_weights_path: Path, clip_image_encoder_huge_weights_path: Path,
test_device: torch.device, test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype, test_dtype_fp32_bf16_fp16: torch.dtype,
) -> CLIPImageEncoderH: ) -> CLIPImageEncoderH:
weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16) encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
tensors = load_from_safetensors(weights) tensors = load_from_safetensors(clip_image_encoder_huge_weights_path)
encoder.load_state_dict(tensors) encoder.load_state_dict(tensors)
return encoder return encoder
@pytest.fixture(scope="module")
def stabilityai_unclip_weights_path(test_weights_path: Path):
r = test_weights_path / "stabilityai" / "stable-diffusion-2-1-unclip"
if not r.is_dir():
warn(f"could not find Stability AI weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_encoder( def ref_encoder(
stabilityai_unclip_weights_path: Path, unclip21_transformers_stabilityai_path: str,
test_device: torch.device, test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype, test_dtype_fp32_bf16_fp16: torch.dtype,
use_local_weights: bool,
) -> CLIPVisionModelWithProjection: ) -> CLIPVisionModelWithProjection:
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
stabilityai_unclip_weights_path, unclip21_transformers_stabilityai_path,
local_files_only=use_local_weights,
subfolder="image_encoder", subfolder="image_encoder",
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) # type: ignore
@no_grad() @no_grad()
View file
@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from warnings import warn
import pytest import pytest
import torch import torch
@ -31,42 +30,39 @@ PROMPTS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def our_encoder( def our_encoder(
test_weights_path: Path, sd15_text_encoder_weights_path: Path,
test_device: torch.device, test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype, test_dtype_fp32_fp16: torch.dtype,
) -> CLIPTextEncoderL: ) -> CLIPTextEncoderL:
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights)
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16) encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
encoder.load_state_dict(tensors) encoder.load_state_dict(tensors)
return encoder return encoder
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def runwayml_weights_path(test_weights_path: Path): def ref_tokenizer(
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5" sd15_diffusers_runwayml_path: str,
if not r.is_dir(): use_local_weights: bool,
warn(f"could not find RunwayML weights at {r}, skipping") ) -> transformers.CLIPTokenizer:
pytest.skip(allow_module_level=True) return transformers.CLIPTokenizer.from_pretrained( # type: ignore
return r sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
subfolder="tokenizer",
@pytest.fixture(scope="module") )
def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
return transformers.CLIPTokenizer.from_pretrained(runwayml_weights_path, subfolder="tokenizer") # type: ignore
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_encoder( def ref_encoder(
runwayml_weights_path: Path, sd15_diffusers_runwayml_path: str,
test_device: torch.device, test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype, test_dtype_fp32_fp16: torch.dtype,
use_local_weights: bool,
) -> transformers.CLIPTextModel: ) -> transformers.CLIPTextModel:
return transformers.CLIPTextModel.from_pretrained( # type: ignore return transformers.CLIPTextModel.from_pretrained( # type: ignore
runwayml_weights_path, sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
subfolder="text_encoder", subfolder="text_encoder",
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore ).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore
View file
@ -4,6 +4,7 @@ from warnings import warn
import pytest import pytest
import torch import torch
from huggingface_hub import hf_hub_download # type: ignore
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.dinov2.dinov2 import ( from refiners.foundationals.dinov2.dinov2 import (
@ -18,7 +19,7 @@ from refiners.foundationals.dinov2.dinov2 import (
) )
from refiners.foundationals.dinov2.vit import ViT from refiners.foundationals.dinov2.vit import ViT
FLAVORS_MAP = { FLAVORS_MAP_REFINERS = {
"dinov2_vits14": DINOv2_small, "dinov2_vits14": DINOv2_small,
"dinov2_vits14_reg": DINOv2_small_reg, "dinov2_vits14_reg": DINOv2_small_reg,
"dinov2_vitb14": DINOv2_base, "dinov2_vitb14": DINOv2_base,
@ -28,6 +29,27 @@ FLAVORS_MAP = {
"dinov2_vitg14": DINOv2_giant, "dinov2_vitg14": DINOv2_giant,
"dinov2_vitg14_reg": DINOv2_giant_reg, "dinov2_vitg14_reg": DINOv2_giant_reg,
} }
FLAVORS_MAP_HUB = {
"dinov2_vits14": "refiners/dinov2.small.patch_14",
"dinov2_vits14_reg": "refiners/dinov2.small.patch_14.reg_4",
"dinov2_vitb14": "refiners/dinov2.base.patch_14",
"dinov2_vitb14_reg": "refiners/dinov2.base.patch_14.reg_4",
"dinov2_vitl14": "refiners/dinov2.large.patch_14",
"dinov2_vitl14_reg": "refiners/dinov2.large.patch_14.reg_4",
"dinov2_vitg14": "refiners/dinov2.giant.patch_14",
"dinov2_vitg14_reg": "refiners/dinov2.giant.patch_14.reg_4",
}
@pytest.fixture(scope="module", params=["float16", "bfloat16"])
def dtype(request: pytest.FixtureRequest) -> torch.dtype:
match request.param:
case "float16":
return torch.float16
case "bfloat16":
return torch.bfloat16
case _ as dtype:
raise ValueError(f"unsupported dtype: {dtype}")
@pytest.fixture(scope="module", params=[224, 518]) @pytest.fixture(scope="module", params=[224, 518])
@ -35,7 +57,7 @@ def resolution(request: pytest.FixtureRequest) -> int:
return request.param return request.param
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys()) @pytest.fixture(scope="module", params=FLAVORS_MAP_REFINERS.keys())
def flavor(request: pytest.FixtureRequest) -> str: def flavor(request: pytest.FixtureRequest) -> str:
return request.param return request.param
@ -53,7 +75,14 @@ def dinov2_repo_path(test_repos_path: Path) -> Path:
def ref_model( def ref_model(
flavor: str, flavor: str,
dinov2_repo_path: Path, dinov2_repo_path: Path,
test_weights_path: Path, dinov2_small_unconverted_weights_path: Path,
dinov2_small_reg4_unconverted_weights_path: Path,
dinov2_base_unconverted_weights_path: Path,
dinov2_base_reg4_unconverted_weights_path: Path,
dinov2_large_unconverted_weights_path: Path,
dinov2_large_reg4_unconverted_weights_path: Path,
dinov2_giant_unconverted_weights_path: Path,
dinov2_giant_reg4_unconverted_weights_path: Path,
test_device: torch.device, test_device: torch.device,
) -> torch.nn.Module: ) -> torch.nn.Module:
kwargs: dict[str, Any] = {} kwargs: dict[str, Any] = {}
@ -69,34 +98,51 @@ def ref_model(
) )
model = model.to(device=test_device) model = model.to(device=test_device)
flavor = flavor.replace("_reg", "_reg4") weight_map = {
weights = test_weights_path / f"{flavor}_pretrain.pth" "dinov2_vits14": dinov2_small_unconverted_weights_path,
if not weights.is_file(): "dinov2_vits14_reg": dinov2_small_reg4_unconverted_weights_path,
warn(f"could not find weights at {weights}, skipping") "dinov2_vitb14": dinov2_base_unconverted_weights_path,
pytest.skip(allow_module_level=True) "dinov2_vitb14_reg": dinov2_base_reg4_unconverted_weights_path,
model.load_state_dict(load_tensors(weights, device=test_device)) "dinov2_vitl14": dinov2_large_unconverted_weights_path,
"dinov2_vitl14_reg": dinov2_large_reg4_unconverted_weights_path,
"dinov2_vitg14": dinov2_giant_unconverted_weights_path,
"dinov2_vitg14_reg": dinov2_giant_reg4_unconverted_weights_path,
}
weights_path = weight_map[flavor]
model.load_state_dict(load_tensors(weights_path, device=test_device))
assert isinstance(model, torch.nn.Module) assert isinstance(model, torch.nn.Module)
return model return model
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def our_model( def our_model(
test_weights_path: Path,
flavor: str, flavor: str,
dinov2_small_weights_path: Path,
dinov2_small_reg4_weights_path: Path,
dinov2_base_weights_path: Path,
dinov2_base_reg4_weights_path: Path,
dinov2_large_weights_path: Path,
dinov2_large_reg4_weights_path: Path,
dinov2_giant_weights_path: Path,
dinov2_giant_reg4_weights_path: Path,
test_device: torch.device, test_device: torch.device,
) -> ViT: ) -> ViT:
model = FLAVORS_MAP[flavor](device=test_device) weight_map = {
"dinov2_vits14": dinov2_small_weights_path,
"dinov2_vits14_reg": dinov2_small_reg4_weights_path,
"dinov2_vitb14": dinov2_base_weights_path,
"dinov2_vitb14_reg": dinov2_base_reg4_weights_path,
"dinov2_vitl14": dinov2_large_weights_path,
"dinov2_vitl14_reg": dinov2_large_reg4_weights_path,
"dinov2_vitg14": dinov2_giant_weights_path,
"dinov2_vitg14_reg": dinov2_giant_reg4_weights_path,
}
weights_path = weight_map[flavor]
flavor = flavor.replace("_reg", "_reg4") model = FLAVORS_MAP_REFINERS[flavor](device=test_device)
weights = test_weights_path / f"{flavor}_pretrain.safetensors" tensors = load_from_safetensors(weights_path)
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights)
model.load_state_dict(tensors) model.load_state_dict(tensors)
return model return model
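Both fixtures above request all eight per-flavor path fixtures and then pick one out of a dict. An equivalent, lazier selection is possible with `request.getfixturevalue`; this is only a sketch of that design choice, under the assumption that the fixture names match the map below:
from pathlib import Path
import pytest
@pytest.fixture(scope="module")
def flavor_weights_path(flavor: str, request: pytest.FixtureRequest) -> Path:
    # hypothetical helper fixture: resolves only the one weights fixture the current flavor needs
    fixture_name = {
        "dinov2_vits14": "dinov2_small_weights_path",
        "dinov2_vits14_reg": "dinov2_small_reg4_weights_path",
        "dinov2_vitb14": "dinov2_base_weights_path",
        "dinov2_vitb14_reg": "dinov2_base_reg4_weights_path",
        "dinov2_vitl14": "dinov2_large_weights_path",
        "dinov2_vitl14_reg": "dinov2_large_reg4_weights_path",
        "dinov2_vitg14": "dinov2_giant_weights_path",
        "dinov2_vitg14_reg": "dinov2_giant_reg4_weights_path",
    }[flavor]
    return request.getfixturevalue(fixture_name)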
View file
@ -0,0 +1,139 @@
from pathlib import Path
import pytest
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
DoubleTextEncoder,
SD1Autoencoder,
SD1UNet,
SDXLAutoencoder,
SDXLUNet,
StableDiffusion_1,
StableDiffusion_XL,
)
@pytest.fixture(scope="module")
def refiners_sd15_autoencoder(sd15_autoencoder_weights_path: Path) -> SD1Autoencoder:
autoencoder = SD1Autoencoder()
tensors = load_from_safetensors(sd15_autoencoder_weights_path)
autoencoder.load_state_dict(tensors)
return autoencoder
@pytest.fixture(scope="module")
def refiners_sd15_unet(sd15_unet_weights_path: Path) -> SD1UNet:
unet = SD1UNet(in_channels=4)
tensors = load_from_safetensors(sd15_unet_weights_path)
unet.load_state_dict(tensors)
return unet
@pytest.fixture(scope="module")
def refiners_sd15_text_encoder(sd15_text_encoder_weights_path: Path) -> CLIPTextEncoderL:
text_encoder = CLIPTextEncoderL()
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
text_encoder.load_state_dict(tensors)
return text_encoder
@pytest.fixture(scope="module")
def refiners_sd15(
refiners_sd15_autoencoder: SD1Autoencoder,
refiners_sd15_unet: SD1UNet,
refiners_sd15_text_encoder: CLIPTextEncoderL,
) -> StableDiffusion_1:
return StableDiffusion_1(
lda=refiners_sd15_autoencoder,
unet=refiners_sd15_unet,
clip_text_encoder=refiners_sd15_text_encoder,
)
@pytest.fixture(scope="module")
def refiners_sdxl_autoencoder(sdxl_autoencoder_weights_path: Path) -> SDXLAutoencoder:
autoencoder = SDXLAutoencoder()
tensors = load_from_safetensors(sdxl_autoencoder_weights_path)
autoencoder.load_state_dict(tensors)
return autoencoder
@pytest.fixture(scope="module")
def refiners_sdxl_unet(sdxl_unet_weights_path: Path) -> SDXLUNet:
unet = SDXLUNet(in_channels=4)
tensors = load_from_safetensors(sdxl_unet_weights_path)
unet.load_state_dict(tensors)
return unet
@pytest.fixture(scope="module")
def refiners_sdxl_text_encoder(sdxl_text_encoder_weights_path: Path) -> DoubleTextEncoder:
text_encoder = DoubleTextEncoder()
tensors = load_from_safetensors(sdxl_text_encoder_weights_path)
text_encoder.load_state_dict(tensors)
return text_encoder
@pytest.fixture(scope="module")
def refiners_sdxl(
refiners_sdxl_autoencoder: SDXLAutoencoder,
refiners_sdxl_unet: SDXLUNet,
refiners_sdxl_text_encoder: DoubleTextEncoder,
) -> StableDiffusion_XL:
return StableDiffusion_XL(
lda=refiners_sdxl_autoencoder,
unet=refiners_sdxl_unet,
clip_text_encoder=refiners_sdxl_text_encoder,
)
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
def refiners_autoencoder(
request: pytest.FixtureRequest,
refiners_sd15_autoencoder: SD1Autoencoder,
refiners_sdxl_autoencoder: SDXLAutoencoder,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> SD1Autoencoder | SDXLAutoencoder:
model_version = request.param
match (model_version, test_dtype_fp32_bf16_fp16):
case ("SD1.5", _):
return refiners_sd15_autoencoder
case ("SDXL", torch.float16):
return refiners_sdxl_autoencoder
case ("SDXL", _):
return refiners_sdxl_autoencoder
case _:
raise ValueError(f"Unknown model version: {model_version}")
@pytest.fixture(scope="module")
def diffusers_sd15_pipeline(
sd15_diffusers_runwayml_path: str,
use_local_weights: bool,
) -> StableDiffusionPipeline:
return StableDiffusionPipeline.from_pretrained( # type: ignore
sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
)
@pytest.fixture(scope="module")
def diffusers_sdxl_pipeline(
sdxl_diffusers_stabilityai_path: str,
use_local_weights: bool,
) -> StableDiffusionXLPipeline:
return StableDiffusionXLPipeline.from_pretrained( # type: ignore
sdxl_diffusers_stabilityai_path,
local_files_only=use_local_weights,
)
@pytest.fixture(scope="module")
def diffusers_sdxl_unet(diffusers_sdxl_pipeline: StableDiffusionXLPipeline) -> UNet2DConditionModel:
return diffusers_sdxl_pipeline.unet # type: ignore
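The composite fixtures above assemble full refiners models from the converted weights. A minimal smoke-test sketch of consuming `refiners_sd15`: one denoising step at step 0 followed by a shape check; the prompt is arbitrary and the model stays on whatever device the fixture built it on:
import torch
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_1
@no_grad()
def test_sd15_smoke(refiners_sd15: StableDiffusion_1) -> None:
    sd = refiners_sd15
    latent_noise = torch.randn(1, 4, 64, 64, device=sd.device)
    clip_text_embedding = sd.compute_clip_text_embedding("a cute cat")
    output = sd(latent_noise, step=0, clip_text_embedding=clip_text_embedding)
    assert output.shape == (1, 4, 64, 64)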
View file
@ -1,134 +0,0 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
@pytest.fixture(scope="module")
def ref_path() -> Path:
return Path(__file__).parent / "test_auto_encoder_ref"
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
def lda(
request: pytest.FixtureRequest,
test_weights_path: Path,
test_dtype_fp32_bf16_fp16: torch.dtype,
test_device: torch.device,
) -> LatentDiffusionAutoencoder:
model_version = request.param
match (model_version, test_dtype_fp32_bf16_fp16):
case ("SD1.5", _):
weight_path = test_weights_path / "lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SD1Autoencoder().load_from_safetensors(weight_path)
case ("SDXL", torch.float16):
weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case ("SDXL", _):
weight_path = test_weights_path / "sdxl-lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case _:
raise ValueError(f"Unknown model version: {model_version}")
model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
return model
@pytest.fixture(scope="module")
def sample_image(ref_path: Path) -> Image.Image:
test_image = ref_path / "macaw.png"
if not test_image.is_file():
warn(f"could not reference image at {test_image}, skipping")
pytest.skip(allow_module_level=True)
img = Image.open(test_image) # type: ignore
assert img.size == (512, 512)
return img
@no_grad()
def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = lda.image_to_latents(sample_image)
decoded = lda.latents_to_image(encoded)
assert decoded.mode == "RGB" # type: ignore
# Ensure no saturation. The green channel (band = 1) must not max out.
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
@no_grad()
def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = lda.images_to_latents([sample_image, sample_image])
images = lda.latents_to_images(encoded)
assert isinstance(images, list)
assert len(images) == 2
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
@no_grad()
def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(512, 1024)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((1024, 1024)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(2048, 2048)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
@no_grad()
def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.crop((0, 0, 300, 500))
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
with pytest.raises(ValueError):
lda.tiled_image_to_latents(sample_image)
with pytest.raises(ValueError):
lda.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=lda.device))
View file
@ -0,0 +1,104 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@pytest.fixture(scope="module")
def sample_image() -> Image.Image:
test_image = Path(__file__).parent / "test_auto_encoder_ref" / "macaw.png"
if not test_image.is_file():
warn(f"could not reference image at {test_image}, skipping")
pytest.skip(allow_module_level=True)
img = Image.open(test_image) # type: ignore
assert img.size == (512, 512)
return img
@pytest.fixture(scope="module")
def autoencoder(
refiners_autoencoder: LatentDiffusionAutoencoder,
test_device: torch.device,
) -> LatentDiffusionAutoencoder:
return refiners_autoencoder.to(test_device)
@no_grad()
def test_encode_decode_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = autoencoder.image_to_latents(sample_image)
decoded = autoencoder.latents_to_image(encoded)
assert decoded.mode == "RGB" # type: ignore
# Ensure no saturation. The green channel (band = 1) must not max out.
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
@no_grad()
def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = autoencoder.images_to_latents([sample_image, sample_image])
images = autoencoder.latents_to_images(encoded)
assert isinstance(images, list)
assert len(images) == 2
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
@no_grad()
def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(512, 1024)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((1024, 1024)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(2048, 2048)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
@no_grad()
def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.crop((0, 0, 300, 500))
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
def test_value_error_tile_encode_no_context(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
with pytest.raises(ValueError):
autoencoder.tiled_image_to_latents(sample_image)
with pytest.raises(ValueError):
autoencoder.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=autoencoder.device))
View file
@ -1,31 +0,0 @@
import torch
from PIL import Image
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_1_Inpainting
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
@no_grad()
def test_sample_noise():
manual_seed(2)
latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))
manual_seed(2)
latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
@no_grad()
def test_sd1_inpainting(test_device: torch.device) -> None:
sd = StableDiffusion_1_Inpainting(device=test_device)
latent_noise = torch.randn(1, 4, 64, 64, device=test_device)
target_image = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512))
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
text_embedding = sd.compute_clip_text_embedding("")
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
assert output.shape == (1, 4, 64, 64)
View file
@ -0,0 +1,65 @@
import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from torch import Tensor
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
@no_grad()
def test_text_encoder(
diffusers_sd15_pipeline: StableDiffusionPipeline,
refiners_sd15_text_encoder: CLIPTextEncoderL,
) -> None:
"""Compare our refiners implementation with the diffusers implementation."""
manual_seed(seed=0) # unnecessary, but just in case
prompt = "A photo of a pizza."
negative_prompt = ""
atol = 1e-2 # FIXME: very high tolerance, figure out why
( # encode text prompts using diffusers pipeline
diffusers_embeds, # type: ignore
diffusers_negative_embeds, # type: ignore
) = diffusers_sd15_pipeline.encode_prompt( # type: ignore
prompt=prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
device=diffusers_sd15_pipeline.device,
)
assert isinstance(diffusers_embeds, Tensor)
assert isinstance(diffusers_negative_embeds, Tensor)
# encode text prompts using refiners model
refiners_embeds = refiners_sd15_text_encoder(prompt)
refiners_negative_embeds = refiners_sd15_text_encoder("")
# check that the shapes are the same
assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 768)
assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 768)
# check that the values are close
assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol)
assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol)
@no_grad()
def test_text_encoder_batched(refiners_sd15_text_encoder: CLIPTextEncoderL) -> None:
"""Check that encoding two prompts works as expected whether batched or not."""
manual_seed(seed=0) # unnecessary, but just in case
prompt1 = "A photo of a pizza."
prompt2 = "A giant duck."
atol = 1e-6
# encode the two prompts at once
embeds_batched = refiners_sd15_text_encoder([prompt1, prompt2])
assert embeds_batched.shape == (2, 77, 768)
# encode the prompts one by one
embeds_1 = refiners_sd15_text_encoder(prompt1)
embeds_2 = refiners_sd15_text_encoder(prompt2)
assert embeds_1.shape == embeds_2.shape == (1, 77, 768)
# check that the values are close
assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol)
assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol)
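The 1e-2 tolerance flagged by the FIXME in test_text_encoder could be investigated by reporting the actual error rather than only asserting allclose; a small, purely illustrative diagnostic helper:
from torch import Tensor
def report_mismatch(ours: Tensor, theirs: Tensor) -> None:
    # prints max and mean absolute error, handy when deciding whether atol can be tightened
    diff = (ours - theirs).abs()
    print(f"max abs diff: {diff.max().item():.3e}, mean abs diff: {diff.mean().item():.3e}")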
View file
@ -1,121 +0,0 @@
from pathlib import Path
from typing import Any, Protocol, cast
from warnings import warn
import pytest
import torch
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
class DiffusersSDXL(Protocol):
unet: fl.Module
text_encoder: fl.Module
text_encoder_2: fl.Module
tokenizer: fl.Module
tokenizer_2: fl.Module
vae: fl.Module
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
def encode_prompt(
self,
prompt: str,
prompt_2: str | None = None,
negative_prompt: str | None = None,
negative_prompt_2: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
if not r.is_dir():
warn(message=f"could not find Stability SDXL base weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def double_text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not text_encoder_weights.is_file():
warn(f"could not find weights at {text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return text_encoder_weights
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
from diffusers import DiffusionPipeline # type: ignore
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
@pytest.fixture(scope="module")
def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
double_text_encoder = DoubleTextEncoder()
double_text_encoder.load_from_safetensors(double_text_encoder_weights)
return double_text_encoder
@no_grad()
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt = "A photo of a pizza."
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
double_embedding, pooled_embedding = double_text_encoder(prompt)
assert double_embedding.shape == torch.Size([1, 77, 2048])
assert pooled_embedding.shape == torch.Size([1, 1280])
embedding_1, embedding_2 = cast(
tuple[Tensor, Tensor],
prompt_embeds.split(split_size=[768, 1280], dim=-1), # type: ignore
)
rembedding_1, rembedding_2 = cast(
tuple[Tensor, Tensor],
double_embedding.split(split_size=[768, 1280], dim=-1), # type: ignore
)
assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
@no_grad()
def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt1 = "A photo of a pizza."
prompt2 = "A giant duck."
double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2])
assert double_embedding_b2.shape == torch.Size([2, 77, 2048])
assert pooled_embedding_b2.shape == torch.Size([2, 1280])
double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1)
double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2)
assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
View file
@ -0,0 +1,70 @@
import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from torch import Tensor

from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder


@no_grad()
def test_double_text_encoder(
    diffusers_sdxl_pipeline: StableDiffusionXLPipeline,
    refiners_sdxl_text_encoder: DoubleTextEncoder,
) -> None:
    """Compare our refiners implementation with the diffusers implementation."""
    manual_seed(seed=0)  # unnecessary, but just in case
    prompt = "A photo of a pizza."
    negative_prompt = ""
    atol = 1e-6

    (  # encode text prompts using diffusers pipeline
        diffusers_embeds,
        diffusers_negative_embeds,  # type: ignore
        diffusers_pooled_embeds,  # type: ignore
        diffusers_negative_pooled_embeds,  # type: ignore
    ) = diffusers_sdxl_pipeline.encode_prompt(prompt=prompt, negative_prompt=negative_prompt)
    assert diffusers_negative_embeds is not None
    assert isinstance(diffusers_pooled_embeds, Tensor)
    assert isinstance(diffusers_negative_pooled_embeds, Tensor)

    # encode text prompts using refiners model
    refiners_embeds, refiners_pooled_embeds = refiners_sdxl_text_encoder(prompt)
    refiners_negative_embeds, refiners_negative_pooled_embeds = refiners_sdxl_text_encoder("")

    # check that the shapes are the same
    assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 2048)
    assert diffusers_pooled_embeds.shape == refiners_pooled_embeds.shape == (1, 1280)
    assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 2048)
    assert diffusers_negative_pooled_embeds.shape == refiners_negative_pooled_embeds.shape == (1, 1280)

    # check that the values are close
    assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol)
    assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol)
    assert torch.allclose(input=refiners_negative_pooled_embeds, other=diffusers_negative_pooled_embeds, atol=atol)
    assert torch.allclose(input=refiners_pooled_embeds, other=diffusers_pooled_embeds, atol=atol)


@no_grad()
def test_double_text_encoder_batched(refiners_sdxl_text_encoder: DoubleTextEncoder) -> None:
    """Check that encoding two prompts works as expected whether batched or not."""
    manual_seed(seed=0)  # unnecessary, but just in case
    prompt1 = "A photo of a pizza."
    prompt2 = "A giant duck."
    atol = 1e-6

    # encode the two prompts at once
    embeds_batched, pooled_embeds_batched = refiners_sdxl_text_encoder([prompt1, prompt2])
    assert embeds_batched.shape == (2, 77, 2048)
    assert pooled_embeds_batched.shape == (2, 1280)

    # encode the prompts one by one
    embeds_1, pooled_embeds_1 = refiners_sdxl_text_encoder(prompt1)
    embeds_2, pooled_embeds_2 = refiners_sdxl_text_encoder(prompt2)
    assert embeds_1.shape == embeds_2.shape == (1, 77, 2048)
    assert pooled_embeds_1.shape == pooled_embeds_2.shape == (1, 1280)

    # check that the values are close
    assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol)
    assert torch.allclose(input=pooled_embeds_1, other=pooled_embeds_batched[0].unsqueeze(0), atol=atol)
    assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol)
    assert torch.allclose(input=pooled_embeds_2, other=pooled_embeds_batched[1].unsqueeze(0), atol=atol)
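Note: the diffusers_sdxl_pipeline and refiners_sdxl_text_encoder fixtures used above come from the suite's shared conftest rather than this file. A minimal sketch of what such fixtures could look like, assuming weight-path fixtures in the new style; the fixture bodies and path fixture names below are illustrative, not the actual conftest:

from pathlib import Path

import pytest
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline

from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder


@pytest.fixture(scope="module")
def refiners_sdxl_text_encoder(sdxl_text_encoder_weights_path: Path) -> DoubleTextEncoder:
    # hypothetical weight-path fixture name; loads the converted safetensors weights
    text_encoder = DoubleTextEncoder()
    text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
    return text_encoder


@pytest.fixture(scope="module")
def diffusers_sdxl_pipeline(stabilityai_sdxl_base_path: Path) -> StableDiffusionXLPipeline:
    # hypothetical path fixture pointing at a local stable-diffusion-xl-base-1.0 checkout
    return StableDiffusionXLPipeline.from_pretrained(stabilityai_sdxl_base_path)  # type: ignore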

View file

@@ -1,36 +1,13 @@
-from pathlib import Path
 from typing import Any
-from warnings import warn
 
 import pytest
 import torch
 
-from refiners.fluxion.model_converter import ConversionStage, ModelConverter
+from refiners.conversion.model_converter import ConversionStage, ModelConverter
 from refiners.fluxion.utils import manual_seed, no_grad
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
 
 
-@pytest.fixture(scope="module")
-def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
-    r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
-    if not r.is_dir():
-        warn(f"could not find Stability SDXL base weights at {r}, skipping")
-        pytest.skip(allow_module_level=True)
-    return r
-
-
-@pytest.fixture(scope="module")
-def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
-    from diffusers import DiffusionPipeline  # type: ignore
-
-    return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path)  # type: ignore
-
-
-@pytest.fixture(scope="module")
-def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
-    return diffusers_sdxl.unet
-
-
 @pytest.fixture(scope="module")
 def refiners_sdxl_unet() -> SDXLUNet:
     unet = SDXLUNet(in_channels=4)
@@ -38,7 +15,10 @@ def refiners_sdxl_unet() -> SDXLUNet:
 
 @no_grad()
-def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
+def test_sdxl_unet(
+    diffusers_sdxl_unet: Any,
+    refiners_sdxl_unet: SDXLUNet,
+) -> None:
     source = diffusers_sdxl_unet
     target = refiners_sdxl_unet
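Note: the test above drives both UNets through ModelConverter, whose import moved from refiners.fluxion to refiners.conversion. The general pattern the converter follows is sketched below on toy modules; the keyword names (source_model, target_model, source_args) are assumptions about the API rather than a quote of the test, which uses the real SDXL inputs.

# Hedged illustration of the ModelConverter pattern on toy modules.
import torch

from refiners.conversion.model_converter import ModelConverter

source = torch.nn.Linear(8, 8)
target = torch.nn.Linear(8, 8)
converter = ModelConverter(source_model=source, target_model=target)
x = torch.randn(1, 8)
# maps the source state dict onto the target, then checks that both produce the same output
assert converter.run(source_args=(x,))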

View file

@@ -1,8 +1,7 @@
 import gc
 from pathlib import Path
-from warnings import warn
 
-from pytest import fixture, skip
+from pytest import fixture
 
 
 @fixture(autouse=True)
@@ -15,12 +14,3 @@ def ensure_gc():
 @fixture(scope="package")
 def ref_path(test_sam_path: Path) -> Path:
     return test_sam_path / "test_sam_ref"
-
-
-@fixture(scope="package")
-def sam_h_weights(test_weights_path: Path) -> Path:
-    sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
-    if not sam_h_weights.is_file():
-        warn(f"could not find weights at {sam_h_weights}, skipping")
-        skip(allow_module_level=True)
-    return sam_h_weights
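Note: the ensure_gc autouse fixture kept by this hunk is what frees GPU memory between tests; its body sits outside the shown context. A hypothetical sketch of what such a fixture typically looks like (the real implementation may differ):

import gc

import torch
from pytest import fixture


@fixture(autouse=True)
def ensure_gc():
    yield  # run the test first
    gc.collect()  # drop Python-side references
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA memory to the driver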

View file

@@ -1,6 +1,5 @@
 from pathlib import Path
 from typing import cast
-from warnings import warn
 
 import numpy as np
 import pytest
@@ -36,37 +35,17 @@ def tennis(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "tennis.png").convert("RGB")  # type: ignore
 
 
-@pytest.fixture(scope="module")
-def hq_adapter_weights(test_weights_path: Path) -> Path:
-    """Path to the HQ adapter weights in Refiners format"""
-    refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
-    if not refiners_hq_adapter_sam_weights.is_file():
-        warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    return refiners_hq_adapter_sam_weights
-
-
 @pytest.fixture
-def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
+def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
     # HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False.
     sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
-    sam_h.load_from_safetensors(tensors_path=sam_h_weights)
+    sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
     return sam_h
 
 
 @pytest.fixture(scope="module")
-def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
-    """Path to the HQ adapter weights in default format"""
-    reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth"
-    if not reference_hq_adapter_sam_weights.is_file():
-        warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    return reference_hq_adapter_sam_weights
-
-
-@pytest.fixture(scope="module")
-def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
-    sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights))
+def reference_sam_h(sam_h_hq_adapter_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
+    sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=sam_h_hq_adapter_unconverted_weights_path))
     return sam_h.to(device=test_device)
@@ -142,11 +121,11 @@ def test_mask_decoder_tokens_extender() -> None:
 @no_grad()
 def test_early_vit_embedding(
     sam_h: SegmentAnythingH,
-    hq_adapter_weights: Path,
+    sam_h_hq_adapter_weights_path: Path,
     reference_sam_h: FacebookSAM,
     tennis: Image.Image,
 ) -> None:
-    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+    HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
 
     image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024)))  # type: ignore
@@ -159,8 +138,8 @@ def test_early_vit_embedding(
     assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)
 
 
-def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
-    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+def test_tokens(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM) -> None:
+    HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
 
     mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)
@@ -175,8 +154,10 @@ def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam
 @no_grad()
-def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
-    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+def test_compress_vit_feat(
+    sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
+) -> None:
+    HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
 
     early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
@@ -189,8 +170,10 @@ def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re
 @no_grad()
-def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
-    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+def test_embedding_encoder(
+    sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
+) -> None:
+    HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
 
     x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
@@ -203,8 +186,10 @@ def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re
 @no_grad()
-def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
-    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+def test_hq_token_mlp(
+    sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
+) -> None:
+    HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
 
     x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)
@@ -217,13 +202,13 @@ def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, referen
 @pytest.mark.parametrize("hq_mask_only", [True, False])
 def test_predictor(
     sam_h: SegmentAnythingH,
-    hq_adapter_weights: Path,
+    sam_h_hq_adapter_weights_path: Path,
     hq_mask_only: bool,
     reference_sam_h_predictor: FacebookSAMPredictorHQ,
     tennis: Image.Image,
     one_prompt: SAMPrompt,
 ) -> None:
-    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
     adapter.hq_mask_only = hq_mask_only
     assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
@@ -268,13 +253,13 @@ def test_predictor(
 @pytest.mark.parametrize("hq_mask_only", [True, False])
 def test_predictor_equal(
     sam_h: SegmentAnythingH,
-    hq_adapter_weights: Path,
+    sam_h_hq_adapter_weights_path: Path,
     hq_mask_only: bool,
     reference_sam_h_predictor: FacebookSAMPredictorHQ,
     tennis: Image.Image,
     one_prompt: SAMPrompt,
 ) -> None:
-    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
     adapter.hq_mask_only = hq_mask_only
     assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
@@ -318,8 +303,8 @@ def test_predictor_equal(
 @no_grad()
-def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
-    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
+def test_batch_mask_decoder(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path) -> None:
+    HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
 
     batch_size = 5
@@ -348,8 +333,10 @@ def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -
     assert torch.equal(mask_prediction[0], mask_prediction[1])
 
 
-def test_hq_sam_load_save_weights(sam_h: SegmentAnythingH, hq_adapter_weights: Path, test_device: torch.device) -> None:
-    weights = load_from_safetensors(hq_adapter_weights, device=test_device)
+def test_hq_sam_load_save_weights(
+    sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, test_device: torch.device
+) -> None:
+    weights = load_from_safetensors(sam_h_hq_adapter_weights_path, device=test_device)
 
     hq_sam_adapter = HQSAMAdapter(sam_h)
     out_weights_init = hq_sam_adapter.weights
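Note: outside the test suite, the same calls exercised above are how the HQ adapter is wired onto a base SAM model. A hedged sketch follows; the import paths are assumptions about the package layout and the weight file names reuse the ones from the removed fixtures as illustrative values.

import torch

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter  # assumed module path
from refiners.foundationals.segment_anything.model import SegmentAnythingH  # assumed module path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_h = SegmentAnythingH(multimask_output=False, device=device)  # the HQ adapter expects single-mask output
sam_h.load_from_safetensors(tensors_path="tests/weights/segment-anything-h.safetensors")  # illustrative path
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors("tests/weights/refiners-sam-hq-vit-h.safetensors")).inject()
adapter.hq_mask_only = True  # return only the HQ-refined mask, as toggled in the tests above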

View file

@@ -1,7 +1,6 @@
 from math import isclose
 from pathlib import Path
 from typing import cast
-from warnings import warn
 
 import numpy as np
 import pytest
@@ -17,8 +16,8 @@ from tests.foundationals.segment_anything.utils import (
 from torch import Tensor
 
 import refiners.fluxion.layers as fl
+from refiners.conversion.model_converter import ModelConverter
 from refiners.fluxion import manual_seed
-from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
@@ -49,20 +48,11 @@ def one_prompt() -> SAMPrompt:
 @pytest.fixture(scope="module")
-def facebook_sam_h_weights(test_weights_path: Path) -> Path:
-    sam_h_weights = test_weights_path / "sam_vit_h_4b8939.pth"
-    if not sam_h_weights.is_file():
-        warn(f"could not find weights at {sam_h_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    return sam_h_weights
-
-
-@pytest.fixture(scope="module")
-def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
+def facebook_sam_h(sam_h_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
     from segment_anything import build_sam_vit_h  # type: ignore
 
     sam_h = cast(FacebookSAM, build_sam_vit_h())
-    sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
+    sam_h.load_state_dict(state_dict=load_tensors(sam_h_unconverted_weights_path))
     return sam_h.to(device=test_device)
@@ -76,16 +66,16 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto
 @pytest.fixture(scope="module")
-def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
+def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
     sam_h = SegmentAnythingH(device=test_device)
-    sam_h.load_from_safetensors(tensors_path=sam_h_weights)
+    sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
     return sam_h
 
 
 @pytest.fixture(scope="module")
-def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
+def sam_h_single_output(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
     sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
-    sam_h.load_from_safetensors(tensors_path=sam_h_weights)
+    sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
     return sam_h
@@ -469,7 +459,10 @@ def test_predictor_resized_single_output(
 def test_mask_encoder(
-    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
+    facebook_sam_h_predictor: FacebookSAMPredictor,
+    sam_h: SegmentAnythingH,
+    truck: Image.Image,
+    one_prompt: SAMPrompt,
 ) -> None:
     predictor = facebook_sam_h_predictor
     predictor.set_image(np.array(truck))

View file

@@ -1,5 +1,4 @@
 from pathlib import Path
-from warnings import warn
 
 import pytest
 import torch
@@ -25,16 +24,11 @@ class CifarDataset(Dataset[torch.Tensor]):
 @pytest.fixture(scope="module")
 def dinov2_l(
-    test_weights_path: Path,
+    dinov2_large_weights_path: Path,
     test_device: torch.device,
 ) -> dinov2.DINOv2_large:
-    weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors"
-    if not weights.is_file():
-        warn(f"could not find weights at {weights}, skipping")
-        pytest.skip(allow_module_level=True)
     model = dinov2.DINOv2_large(device=test_device)
-    model.load_from_safetensors(weights)
+    model.load_from_safetensors(dinov2_large_weights_path)
    return model
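Note: with the new fixture, the DINOv2 weights path is injected directly. A minimal usage sketch of the resulting model; the weight path and input resolution below are assumptions for illustration, not taken from the test.

import torch

from refiners.fluxion.utils import no_grad
from refiners.foundationals import dinov2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = dinov2.DINOv2_large(device=device)
model.load_from_safetensors("tests/weights/dinov2_vitl14_pretrain.safetensors")  # illustrative path

with no_grad():
    x = torch.randn(1, 3, 224, 224, device=device)  # assumed input resolution (patch size 14 -> 16x16 patches)
    features = model(x)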

View file

@@ -24,11 +24,20 @@ def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int
 class T5TextEmbedder(nn.Module):
     def __init__(
-        self, pretrained_path: Path = Path("tests/weights/QQGYLab/T5XLFP16"), max_length: int | None = None
+        self,
+        pretrained_path: Path | str,
+        max_length: int | None = None,
+        local_files_only: bool = False,
     ) -> None:
         super().__init__()  # type: ignore[reportUnknownMemberType]
-        self.model: nn.Module = T5EncoderModel.from_pretrained(pretrained_path, local_files_only=True)  # type: ignore
-        self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained(pretrained_path, local_files_only=True)  # type: ignore
+        self.model: nn.Module = T5EncoderModel.from_pretrained(  # type: ignore
+            pretrained_path,
+            local_files_only=local_files_only,
+        )
+        self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained(  # type: ignore
+            pretrained_path,
+            local_files_only=local_files_only,
+        )
         self.max_length = max_length
 
     def forward(