diff --git a/tests/adapters/test_lora_manager.py b/tests/adapters/test_lora_manager.py index 4e30ce8..c2f2e9f 100644 --- a/tests/adapters/test_lora_manager.py +++ b/tests/adapters/test_lora_manager.py @@ -1,5 +1,4 @@ from pathlib import Path -from warnings import warn import pytest import torch @@ -16,14 +15,8 @@ def manager() -> SDLoraManager: @pytest.fixture -def weights(test_weights_path: Path) -> dict[str, torch.Tensor]: - weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin" - - if not weights_path.is_file(): - warn(f"could not find weights at {weights_path}, skipping") - pytest.skip(allow_module_level=True) - - return load_tensors(weights_path) +def weights(lora_pokemon_weights_path: Path) -> dict[str, torch.Tensor]: + return load_tensors(lora_pokemon_weights_path) def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None: diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index a9e92df..8b2e211 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -9,6 +9,8 @@ import torch from PIL import Image from tests.utils import T5TextEmbedder, ensure_similar_images +from refiners.conversion import controllora_sdxl +from refiners.conversion.utils import Hub from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad from refiners.foundationals.clip.concepts import ConceptExtender @@ -45,6 +47,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler i from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter +from ..weight_paths import get_path + def _img_open(path: Path) -> Image.Image: return Image.open(path) # type: ignore @@ -194,69 +198,85 @@ def expected_style_aligned(ref_path: Path) -> Image.Image: @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) def controlnet_data( - ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest + ref_path: Path, + controlnet_depth_weights_path: Path, + controlnet_canny_weights_path: Path, + controlnet_lineart_weights_path: Path, + controlnet_normals_weights_path: Path, + controlnet_sam_weights_path: Path, + request: pytest.FixtureRequest, ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]: cn_name: str = request.param condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") - weights_fn = { - "depth": "lllyasviel_control_v11f1p_sd15_depth", - "canny": "lllyasviel_control_v11p_sd15_canny", - "lineart": "lllyasviel_control_v11p_sd15_lineart", - "normals": "lllyasviel_control_v11p_sd15_normalbae", - "sam": "mfidabel_controlnet-segment-anything", - } - weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors" - yield (cn_name, condition_image, expected_image, weights_path) + weights_fn = { + "depth": controlnet_depth_weights_path, + "canny": controlnet_canny_weights_path, + "lineart": controlnet_lineart_weights_path, + "normals": controlnet_normals_weights_path, + "sam": controlnet_sam_weights_path, + } + weights_path = weights_fn[cn_name] + + yield cn_name, condition_image, expected_image, weights_path @pytest.fixture(scope="module", params=["canny"]) def controlnet_data_scale_decay( - ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest + ref_path: Path, + controlnet_canny_weights_path: Path, + request: pytest.FixtureRequest, ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]: cn_name: str = request.param condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB") - weights_fn = { - "canny": "lllyasviel_control_v11p_sd15_canny", - } - weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors" + weights_fn = { + "canny": controlnet_canny_weights_path, + } + weights_path = weights_fn[cn_name] + yield (cn_name, condition_image, expected_image, weights_path) @pytest.fixture(scope="module") -def controlnet_data_tile(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Image.Image, Path]: +def controlnet_data_tile( + ref_path: Path, + controlnet_tiles_weights_path: Path, +) -> tuple[Image.Image, Image.Image, Path]: condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB") - weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors" - return condition_image, expected_image, weights_path + return condition_image, expected_image, controlnet_tiles_weights_path @pytest.fixture(scope="module") -def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: +def controlnet_data_canny( + ref_path: Path, + controlnet_canny_weights_path: Path, +) -> tuple[str, Image.Image, Image.Image, Path]: cn_name = "canny" condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") - weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors" - return cn_name, condition_image, expected_image, weights_path + return cn_name, condition_image, expected_image, controlnet_canny_weights_path @pytest.fixture(scope="module") -def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: +def controlnet_data_depth( + ref_path: Path, + controlnet_depth_weights_path: Path, +) -> tuple[str, Image.Image, Image.Image, Path]: cn_name = "depth" condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") - weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors" - return cn_name, condition_image, expected_image, weights_path + return cn_name, condition_image, expected_image, controlnet_depth_weights_path @dataclass class ControlLoraConfig: scale: float condition_path: str - weights_path: str + weights: Hub @dataclass @@ -271,38 +291,38 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = { "PyraCanny": ControlLoraConfig( scale=1.0, condition_path="cutecat_guide_PyraCanny.png", - weights_path="refiners_control-lora-canny-rank128.safetensors", + weights=controllora_sdxl.canny.converted, ), }, "expected_controllora_CPDS.png": { "CPDS": ControlLoraConfig( scale=1.0, condition_path="cutecat_guide_CPDS.png", - weights_path="refiners_fooocus_xl_cpds_128.safetensors", + weights=controllora_sdxl.cpds.converted, ), }, "expected_controllora_PyraCanny+CPDS.png": { "PyraCanny": ControlLoraConfig( scale=0.55, condition_path="cutecat_guide_PyraCanny.png", - weights_path="refiners_control-lora-canny-rank128.safetensors", + weights=controllora_sdxl.canny.converted, ), "CPDS": ControlLoraConfig( scale=0.55, condition_path="cutecat_guide_CPDS.png", - weights_path="refiners_fooocus_xl_cpds_128.safetensors", + weights=controllora_sdxl.cpds.converted, ), }, "expected_controllora_disabled.png": { "PyraCanny": ControlLoraConfig( scale=0.0, condition_path="cutecat_guide_PyraCanny.png", - weights_path="refiners_control-lora-canny-rank128.safetensors", + weights=controllora_sdxl.canny.converted, ), "CPDS": ControlLoraConfig( scale=0.0, condition_path="cutecat_guide_CPDS.png", - weights_path="refiners_fooocus_xl_cpds_128.safetensors", + weights=controllora_sdxl.cpds.converted, ), }, } @@ -311,8 +331,8 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = { @pytest.fixture(params=CONTROL_LORA_CONFIGS.items()) def controllora_sdxl_config( request: pytest.FixtureRequest, + use_local_weights: bool, ref_path: Path, - test_weights_path: Path, ) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]: name: str = request.param[0] configs: dict[str, ControlLoraConfig] = request.param[1] @@ -322,7 +342,7 @@ def controllora_sdxl_config( config_name: ControlLoraResolvedConfig( scale=config.scale, condition_image=_img_open(ref_path / config.condition_path).convert("RGB"), - weights_path=test_weights_path / "control-loras" / config.weights_path, + weights_path=get_path(config.weights, use_local_weights), ) for config_name, config in configs.items() } @@ -331,66 +351,57 @@ def controllora_sdxl_config( @pytest.fixture(scope="module") -def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: +def t2i_adapter_data_depth( + ref_path: Path, + t2i_depth_weights_path: Path, +) -> tuple[str, Image.Image, Image.Image, Path]: name = "depth" condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB") - weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors" - return name, condition_image, expected_image, weights_path + return name, condition_image, expected_image, t2i_depth_weights_path @pytest.fixture(scope="module") -def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: +def t2i_adapter_xl_data_canny( + ref_path: Path, + t2i_sdxl_canny_weights_path: Path, +) -> tuple[str, Image.Image, Image.Image, Path]: name = "canny" condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB") - weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors" - - if not weights_path.is_file(): - warn(f"could not find weights at {weights_path}, skipping") - pytest.skip(allow_module_level=True) - - return name, condition_image, expected_image, weights_path + return name, condition_image, expected_image, t2i_sdxl_canny_weights_path @pytest.fixture(scope="module") -def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: +def lora_data_pokemon( + ref_path: Path, + lora_pokemon_weights_path: Path, +) -> tuple[Image.Image, dict[str, torch.Tensor]]: expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB") - weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin" - - if not weights_path.is_file(): - warn(f"could not find weights at {weights_path}, skipping") - pytest.skip(allow_module_level=True) - - tensors = load_tensors(weights_path) + tensors = load_tensors(lora_pokemon_weights_path) return expected_image, tensors @pytest.fixture(scope="module") -def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: +def lora_data_dpo( + ref_path: Path, + lora_dpo_weights_path: Path, +) -> tuple[Image.Image, dict[str, torch.Tensor]]: expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB") - weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors" - - if not weights_path.is_file(): - warn(f"could not find weights at {weights_path}, skipping") - pytest.skip(allow_module_level=True) - - tensors = load_from_safetensors(weights_path) + tensors = load_from_safetensors(lora_dpo_weights_path) return expected_image, tensors @pytest.fixture(scope="module") -def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]: - weights_path = test_weights_path / "loras" / "sliders" - - if not weights_path.is_dir(): - warn(f"could not find weights at {weights_path}, skipping") - pytest.skip(allow_module_level=True) - +def lora_sliders( + lora_slider_age_weights_path: Path, + lora_slider_cartoon_style_weights_path: Path, + lora_slider_eyesize_weights_path: Path, +) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]: return { - "age": load_tensors(weights_path / "age.pt"), # type: ignore - "cartoon_style": load_tensors(weights_path / "cartoon_style.pt"), # type: ignore - "eyesize": load_tensors(weights_path / "eyesize.pt"), # type: ignore + "age": load_tensors(lora_slider_age_weights_path), + "cartoon_style": load_tensors(lora_slider_cartoon_style_weights_path), + "eyesize": load_tensors(lora_slider_eyesize_weights_path), }, { "age": 0.3, "cartoon_style": -0.2, @@ -477,122 +488,12 @@ def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")[""] -@pytest.fixture(scope="module") -def text_encoder_weights(test_weights_path: Path) -> Path: - text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors" - if not text_encoder_weights.is_file(): - warn(f"could not find weights at {text_encoder_weights}, skipping") - pytest.skip(allow_module_level=True) - return text_encoder_weights - - -@pytest.fixture(scope="module") -def lda_weights(test_weights_path: Path) -> Path: - lda_weights = test_weights_path / "lda.safetensors" - if not lda_weights.is_file(): - warn(f"could not find weights at {lda_weights}, skipping") - pytest.skip(allow_module_level=True) - return lda_weights - - -@pytest.fixture(scope="module") -def unet_weights_std(test_weights_path: Path) -> Path: - unet_weights_std = test_weights_path / "unet.safetensors" - if not unet_weights_std.is_file(): - warn(f"could not find weights at {unet_weights_std}, skipping") - pytest.skip(allow_module_level=True) - return unet_weights_std - - -@pytest.fixture(scope="module") -def unet_weights_inpainting(test_weights_path: Path) -> Path: - unet_weights_inpainting = test_weights_path / "inpainting" / "unet.safetensors" - if not unet_weights_inpainting.is_file(): - warn(f"could not find weights at {unet_weights_inpainting}, skipping") - pytest.skip(allow_module_level=True) - return unet_weights_inpainting - - -@pytest.fixture(scope="module") -def lda_ft_mse_weights(test_weights_path: Path) -> Path: - lda_weights = test_weights_path / "lda_ft_mse.safetensors" - if not lda_weights.is_file(): - warn(f"could not find weights at {lda_weights}, skipping") - pytest.skip(allow_module_level=True) - return lda_weights - - -@pytest.fixture(scope="module") -def ella_weights(test_weights_path: Path) -> tuple[Path, Path]: - ella_adapter_weights = test_weights_path / "ELLA-Adapter" / "ella-sd1.5-tsc-t5xl.safetensors" - if not ella_adapter_weights.is_file(): - warn(f"could not find weights at {ella_adapter_weights}, skipping") - pytest.skip(allow_module_level=True) - t5xl_weights = test_weights_path / "QQGYLab" / "T5XLFP16" - t5xl_files = [ - "config.json", - "model.safetensors", - "special_tokens_map.json", - "spiece.model", - "tokenizer_config.json", - "tokenizer.json", - ] - for file in t5xl_files: - if not (t5xl_weights / file).is_file(): - warn(f"could not find weights at {t5xl_weights / file}, skipping") - pytest.skip(allow_module_level=True) - - return (ella_adapter_weights, t5xl_weights) - - -@pytest.fixture(scope="module") -def ip_adapter_weights(test_weights_path: Path) -> Path: - ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors" - if not ip_adapter_weights.is_file(): - warn(f"could not find weights at {ip_adapter_weights}, skipping") - pytest.skip(allow_module_level=True) - return ip_adapter_weights - - -@pytest.fixture(scope="module") -def ip_adapter_plus_weights(test_weights_path: Path) -> Path: - ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors" - if not ip_adapter_weights.is_file(): - warn(f"could not find weights at {ip_adapter_weights}, skipping") - pytest.skip(allow_module_level=True) - return ip_adapter_weights - - -@pytest.fixture(scope="module") -def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path: - ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors" - if not ip_adapter_weights.is_file(): - warn(f"could not find weights at {ip_adapter_weights}, skipping") - pytest.skip(allow_module_level=True) - return ip_adapter_weights - - -@pytest.fixture(scope="module") -def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path: - ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors" - if not ip_adapter_weights.is_file(): - warn(f"could not find weights at {ip_adapter_weights}, skipping") - pytest.skip(allow_module_level=True) - return ip_adapter_weights - - -@pytest.fixture(scope="module") -def image_encoder_weights(test_weights_path: Path) -> Path: - image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors" - if not image_encoder_weights.is_file(): - warn(f"could not find weights at {image_encoder_weights}, skipping") - pytest.skip(allow_module_level=True) - return image_encoder_weights - - @pytest.fixture def sd15_std( - text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -600,16 +501,19 @@ def sd15_std( sd15 = StableDiffusion_1(device=test_device) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_std) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 @pytest.fixture def sd15_std_sde( - text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -618,16 +522,19 @@ def sd15_std_sde( sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0)) sd15 = StableDiffusion_1(device=test_device, solver=sde_solver) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_std) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 @pytest.fixture def sd15_std_float16( - text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -635,18 +542,18 @@ def sd15_std_float16( sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_std) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 @pytest.fixture def sd15_std_bfloat16( - text_encoder_weights: Path, - lda_weights: Path, - unet_weights_std: Path, + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_weights_path: Path, test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": @@ -655,16 +562,19 @@ def sd15_std_bfloat16( sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_std) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 @pytest.fixture def sd15_inpainting( - text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_inpainting_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1_Inpainting: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -673,16 +583,19 @@ def sd15_inpainting( unet = SD1UNet(in_channels=9) sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_inpainting) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path) return sd15 @pytest.fixture def sd15_inpainting_float16( - text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_inpainting_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1_Inpainting: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -691,16 +604,19 @@ def sd15_inpainting_float16( unet = SD1UNet(in_channels=9) sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_inpainting) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path) return sd15 @pytest.fixture def sd15_ddim( - text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -709,16 +625,19 @@ def sd15_ddim( ddim_solver = DDIM(num_inference_steps=20) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_std) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 @pytest.fixture def sd15_ddim_karras( - text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -727,16 +646,18 @@ def sd15_ddim_karras( ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS)) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_std) - + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 @pytest.fixture def sd15_euler( - text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -745,16 +666,19 @@ def sd15_euler( euler_solver = Euler(num_inference_steps=30) sd15 = StableDiffusion_1(solver=euler_solver, device=test_device) - sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) - sd15.lda.load_from_safetensors(lda_weights) - sd15.unet.load_from_safetensors(unet_weights_std) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 @pytest.fixture def sd15_ddim_lda_ft_mse( - text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_mse_weights_path: Path, + sd15_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_1: if test_device.type == "cpu": warn("not running on CPU, skipping") @@ -763,52 +687,19 @@ def sd15_ddim_lda_ft_mse( ddim_solver = DDIM(num_inference_steps=20) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) - sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) - sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights)) - sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std)) + sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path) + sd15.lda.load_from_safetensors(sd15_autoencoder_mse_weights_path) + sd15.unet.load_from_safetensors(sd15_unet_weights_path) return sd15 -@pytest.fixture -def sdxl_lda_weights(test_weights_path: Path) -> Path: - sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors" - if not sdxl_lda_weights.is_file(): - warn(message=f"could not find weights at {sdxl_lda_weights}, skipping") - pytest.skip(allow_module_level=True) - return sdxl_lda_weights - - -@pytest.fixture -def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path: - sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors" - if not sdxl_lda_weights.is_file(): - warn(message=f"could not find weights at {sdxl_lda_weights}, skipping") - pytest.skip(allow_module_level=True) - return sdxl_lda_weights - - -@pytest.fixture -def sdxl_unet_weights(test_weights_path: Path) -> Path: - sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors" - if not sdxl_unet_weights.is_file(): - warn(message=f"could not find weights at {sdxl_unet_weights}, skipping") - pytest.skip(allow_module_level=True) - return sdxl_unet_weights - - -@pytest.fixture -def sdxl_text_encoder_weights(test_weights_path: Path) -> Path: - sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors" - if not sdxl_double_text_encoder_weights.is_file(): - warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping") - pytest.skip(allow_module_level=True) - return sdxl_double_text_encoder_weights - - @pytest.fixture def sdxl_ddim( - sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device + sdxl_text_encoder_weights_path: Path, + sdxl_autoencoder_weights_path: Path, + sdxl_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_XL: if test_device.type == "cpu": warn(message="not running on CPU, skipping") @@ -817,16 +708,19 @@ def sdxl_ddim( solver = DDIM(num_inference_steps=30) sdxl = StableDiffusion_XL(solver=solver, device=test_device) - sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights) - sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights) + sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_weights_path) + sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path) return sdxl @pytest.fixture def sdxl_ddim_lda_fp16_fix( - sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device + sdxl_text_encoder_weights_path: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_weights_path: Path, + test_device: torch.device, ) -> StableDiffusion_XL: if test_device.type == "cpu": warn(message="not running on CPU, skipping") @@ -835,9 +729,9 @@ def sdxl_ddim_lda_fp16_fix( solver = DDIM(num_inference_steps=30) sdxl = StableDiffusion_XL(solver=solver, device=test_device) - sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights) - sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights) + sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path) + sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path) return sdxl @@ -856,23 +750,18 @@ def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_X @pytest.fixture(scope="module") def multi_upscaler( - test_weights_path: Path, - unet_weights_std: Path, - text_encoder_weights: Path, - lda_ft_mse_weights: Path, + controlnet_tiles_weights_path: Path, + sd15_text_encoder_weights_path: Path, + sd15_autoencoder_mse_weights_path: Path, + sd15_unet_weights_path: Path, test_device: torch.device, ) -> MultiUpscaler: - controlnet_tile_weights = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors" - if not controlnet_tile_weights.is_file(): - warn(message=f"could not find weights at {controlnet_tile_weights}, skipping") - pytest.skip(allow_module_level=True) - return MultiUpscaler( checkpoints=UpscalerCheckpoints( - unet=unet_weights_std, - clip_text_encoder=text_encoder_weights, - lda=lda_ft_mse_weights, - controlnet_tile=controlnet_tile_weights, + unet=sd15_unet_weights_path, + clip_text_encoder=sd15_text_encoder_weights_path, + lda=sd15_autoencoder_mse_weights_path, + controlnet_tile=controlnet_tiles_weights_path, ), device=test_device, dtype=torch.float32, @@ -891,7 +780,9 @@ def expected_multi_upscaler(ref_path: Path) -> Image.Image: @no_grad() def test_diffusion_std_random_init( - sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device + sd15_std: StableDiffusion_1, + expected_image_std_random_init: Image.Image, + test_device: torch.device, ): sd15 = sd15_std @@ -1553,6 +1444,9 @@ def test_diffusion_sdxl_control_lora( adapters: dict[str, ControlLoraAdapter] = {} for config_name, config in configs.items(): + if not config.weights_path.is_file(): + pytest.skip(f"File not found: {config.weights_path}") + adapter = ControlLoraAdapter( name=config_name, scale=config.scale, @@ -1922,13 +1816,18 @@ def test_diffusion_textual_inversion_random_init( @no_grad() def test_diffusion_ella_adapter( sd15_std_float16: StableDiffusion_1, - ella_weights: tuple[Path, Path], + ella_sd15_tsc_t5xl_weights_path: Path, + t5xl_transformers_path: str, expected_image_ella_adapter: Image.Image, test_device: torch.device, + use_local_weights: bool, ): sd15 = sd15_std_float16 - ella_adapter_weights, t5xl_weights = ella_weights - t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16) + t5_encoder = T5TextEmbedder( + pretrained_path=t5xl_transformers_path, + local_files_only=use_local_weights, + max_length=128, + ).to(test_device, torch.float16) prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region" negative_prompt = "" @@ -1938,7 +1837,7 @@ def test_diffusion_ella_adapter( llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt) prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16) - adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_adapter_weights)) + adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_sd15_tsc_t5xl_weights_path)) adapter.inject() sd15.set_inference_steps(50) manual_seed(1001) @@ -1959,8 +1858,8 @@ def test_diffusion_ella_adapter( @no_grad() def test_diffusion_ip_adapter( sd15_ddim_lda_ft_mse: StableDiffusion_1, - ip_adapter_weights: Path, - image_encoder_weights: Path, + ip_adapter_sd15_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, woman_image: Image.Image, expected_image_ip_adapter_woman: Image.Image, test_device: torch.device, @@ -1976,8 +1875,8 @@ def test_diffusion_ip_adapter( prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights)) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path)) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) @@ -2004,8 +1903,8 @@ def test_diffusion_ip_adapter( @no_grad() def test_diffusion_ip_adapter_multi( sd15_ddim_lda_ft_mse: StableDiffusion_1, - ip_adapter_weights: Path, - image_encoder_weights: Path, + ip_adapter_sd15_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, woman_image: Image.Image, statue_image: Image.Image, expected_image_ip_adapter_multi: Image.Image, @@ -2016,8 +1915,8 @@ def test_diffusion_ip_adapter_multi( prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights)) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path)) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) @@ -2044,8 +1943,8 @@ def test_diffusion_ip_adapter_multi( @no_grad() def test_diffusion_sdxl_ip_adapter( sdxl_ddim: StableDiffusion_XL, - sdxl_ip_adapter_weights: Path, - image_encoder_weights: Path, + ip_adapter_sdxl_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, woman_image: Image.Image, expected_image_sdxl_ip_adapter_woman: Image.Image, test_device: torch.device, @@ -2055,8 +1954,8 @@ def test_diffusion_sdxl_ip_adapter( prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights)) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path)) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() with no_grad(): @@ -2093,8 +1992,8 @@ def test_diffusion_sdxl_ip_adapter( @no_grad() def test_diffusion_ip_adapter_controlnet( sd15_ddim: StableDiffusion_1, - ip_adapter_weights: Path, - image_encoder_weights: Path, + ip_adapter_sd15_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, lora_data_pokemon: tuple[Image.Image, Path], controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path], expected_image_ip_adapter_controlnet: Image.Image, @@ -2107,8 +2006,8 @@ def test_diffusion_ip_adapter_controlnet( prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights)) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path)) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() depth_controlnet = SD1ControlnetAdapter( @@ -2149,8 +2048,8 @@ def test_diffusion_ip_adapter_controlnet( @no_grad() def test_diffusion_ip_adapter_plus( sd15_ddim_lda_ft_mse: StableDiffusion_1, - ip_adapter_plus_weights: Path, - image_encoder_weights: Path, + ip_adapter_sd15_plus_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, statue_image: Image.Image, expected_image_ip_adapter_plus_statue: Image.Image, test_device: torch.device, @@ -2161,9 +2060,9 @@ def test_diffusion_ip_adapter_plus( negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" ip_adapter = SD1IPAdapter( - target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True + target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_plus_weights_path), fine_grained=True ) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) @@ -2190,8 +2089,8 @@ def test_diffusion_ip_adapter_plus( @no_grad() def test_diffusion_sdxl_ip_adapter_plus( sdxl_ddim: StableDiffusion_XL, - sdxl_ip_adapter_plus_weights: Path, - image_encoder_weights: Path, + ip_adapter_sdxl_plus_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, woman_image: Image.Image, expected_image_sdxl_ip_adapter_plus_woman: Image.Image, test_device: torch.device, @@ -2202,9 +2101,9 @@ def test_diffusion_sdxl_ip_adapter_plus( negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" ip_adapter = SDXLIPAdapter( - target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True + target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path), fine_grained=True ) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( @@ -2608,8 +2507,8 @@ def test_freeu( def test_hello_world( sdxl_ddim_lda_fp16_fix: StableDiffusion_XL, t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path], - sdxl_ip_adapter_weights: Path, - image_encoder_weights: Path, + ip_adapter_sdxl_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image], ) -> None: sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16) @@ -2622,8 +2521,8 @@ def test_hello_world( warn(f"could not find weights at {weights_path}, skipping") pytest.skip(allow_module_level=True) - ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights)) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path)) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt)) @@ -2743,24 +2642,19 @@ def expected_ic_light(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_ic_light.png").convert("RGB") -@pytest.fixture(scope="module") -def ic_light_sd15_fc_weights(test_weights_path: Path) -> Path: - return test_weights_path / "iclight_sd15_fc-refiners.safetensors" - - @pytest.fixture(scope="module") def ic_light_sd15_fc( - ic_light_sd15_fc_weights: Path, - unet_weights_std: Path, - lda_weights: Path, - text_encoder_weights: Path, + ic_light_sd15_fc_weights_path: Path, + sd15_unet_weights_path: Path, + sd15_autoencoder_weights_path: Path, + sd15_text_encoder_weights_path: Path, test_device: torch.device, ) -> ICLight: return ICLight( - patch_weights=load_from_safetensors(ic_light_sd15_fc_weights), - unet=SD1UNet(in_channels=4).load_from_safetensors(unet_weights_std), - lda=SD1Autoencoder().load_from_safetensors(lda_weights), - clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(text_encoder_weights), + patch_weights=load_from_safetensors(ic_light_sd15_fc_weights_path), + unet=SD1UNet(in_channels=4).load_from_safetensors(sd15_unet_weights_path), + lda=SD1Autoencoder().load_from_safetensors(sd15_autoencoder_weights_path), + clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(sd15_text_encoder_weights_path), device=test_device, ) diff --git a/tests/e2e/test_doc_examples.py b/tests/e2e/test_doc_examples.py index 7514c25..0289b9f 100644 --- a/tests/e2e/test_doc_examples.py +++ b/tests/e2e/test_doc_examples.py @@ -29,74 +29,11 @@ def ref_path(test_e2e_path: Path) -> Path: return test_e2e_path / "test_doc_examples_ref" -@pytest.fixture(scope="module") -def sdxl_text_encoder_weights(test_weights_path: Path) -> Path: - path = test_weights_path / "DoubleCLIPTextEncoder.safetensors" - if not path.is_file(): - warn(message=f"could not find weights at {path}, skipping") - pytest.skip(allow_module_level=True) - return path - - -@pytest.fixture(scope="module") -def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path: - path = test_weights_path / "sdxl-lda-fp16-fix.safetensors" - if not path.is_file(): - warn(message=f"could not find weights at {path}, skipping") - pytest.skip(allow_module_level=True) - return path - - -@pytest.fixture(scope="module") -def sdxl_unet_weights(test_weights_path: Path) -> Path: - path = test_weights_path / "sdxl-unet.safetensors" - if not path.is_file(): - warn(message=f"could not find weights at {path}, skipping") - pytest.skip(allow_module_level=True) - return path - - -@pytest.fixture(scope="module") -def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path: - path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors" - if not path.is_file(): - warn(f"could not find weights at {path}, skipping") - pytest.skip(allow_module_level=True) - return path - - -@pytest.fixture(scope="module") -def image_encoder_weights(test_weights_path: Path) -> Path: - path = test_weights_path / "CLIPImageEncoderH.safetensors" - if not path.is_file(): - warn(f"could not find weights at {path}, skipping") - pytest.skip(allow_module_level=True) - return path - - -@pytest.fixture -def scifi_lora_weights(test_weights_path: Path) -> Path: - path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors" - if not path.is_file(): - warn(message=f"could not find weights at {path}, skipping") - pytest.skip(allow_module_level=True) - return path - - -@pytest.fixture -def pixelart_lora_weights(test_weights_path: Path) -> Path: - path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors" - if not path.is_file(): - warn(message=f"could not find weights at {path}, skipping") - pytest.skip(allow_module_level=True) - return path - - @pytest.fixture def sdxl( - sdxl_text_encoder_weights: Path, - sdxl_lda_fp16_fix_weights: Path, - sdxl_unet_weights: Path, + sdxl_text_encoder_weights_path: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_weights_path: Path, test_device: torch.device, ) -> StableDiffusion_XL: if test_device.type == "cpu": @@ -105,9 +42,9 @@ def sdxl( sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16) - sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights) - sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights) + sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path) + sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path) return sdxl @@ -180,7 +117,7 @@ def test_guide_adapting_sdxl_vanilla( def test_guide_adapting_sdxl_single_lora( test_device: torch.device, sdxl: StableDiffusion_XL, - scifi_lora_weights: Path, + lora_scifi_weights_path: Path, expected_image_guide_adapting_sdxl_single_lora: Image.Image, ) -> None: if test_device.type == "cpu": @@ -195,7 +132,7 @@ def test_guide_adapting_sdxl_single_lora( sdxl.set_self_attention_guidance(enable=True, scale=0.75) manager = SDLoraManager(sdxl) - manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights)) + manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path)) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt + ", best quality, high quality", @@ -222,8 +159,8 @@ def test_guide_adapting_sdxl_single_lora( def test_guide_adapting_sdxl_multiple_loras( test_device: torch.device, sdxl: StableDiffusion_XL, - scifi_lora_weights: Path, - pixelart_lora_weights: Path, + lora_scifi_weights_path: Path, + lora_pixelart_weights_path: Path, expected_image_guide_adapting_sdxl_multiple_loras: Image.Image, ) -> None: if test_device.type == "cpu": @@ -238,8 +175,8 @@ def test_guide_adapting_sdxl_multiple_loras( sdxl.set_self_attention_guidance(enable=True, scale=0.75) manager = SDLoraManager(sdxl) - manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights)) - manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4) + manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path)) + manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.4) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt + ", best quality, high quality", @@ -266,10 +203,10 @@ def test_guide_adapting_sdxl_multiple_loras( def test_guide_adapting_sdxl_loras_ip_adapter( test_device: torch.device, sdxl: StableDiffusion_XL, - sdxl_ip_adapter_plus_weights: Path, - image_encoder_weights: Path, - scifi_lora_weights: Path, - pixelart_lora_weights: Path, + ip_adapter_sdxl_plus_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, + lora_scifi_weights_path: Path, + lora_pixelart_weights_path: Path, image_prompt_german_castle: Image.Image, expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image, ) -> None: @@ -285,16 +222,16 @@ def test_guide_adapting_sdxl_loras_ip_adapter( sdxl.set_self_attention_guidance(enable=True, scale=0.75) manager = SDLoraManager(sdxl) - manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5) - manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55) + manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path), scale=1.5) + manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.55) ip_adapter = SDXLIPAdapter( target=sdxl.unet, - weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), + weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path), scale=1.0, fine_grained=True, ) - ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) + ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path) ip_adapter.inject() clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( diff --git a/tests/e2e/test_lcm.py b/tests/e2e/test_lcm.py index 5b9e50d..c2e03e3 100644 --- a/tests/e2e/test_lcm.py +++ b/tests/e2e/test_lcm.py @@ -26,51 +26,6 @@ def ensure_gc(): gc.collect() -@pytest.fixture -def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl-lda-fp16-fix.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_unet_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl-unet.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_lcm_unet_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl-lcm-unet.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_text_encoder_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "DoubleCLIPTextEncoder.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_lcm_lora_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl-lcm-lora.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - @pytest.fixture(scope="module") def ref_path(test_e2e_path: Path) -> Path: return test_e2e_path / "test_lcm_ref" @@ -94,9 +49,9 @@ def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image: @no_grad() def test_lcm_base( test_device: torch.device, - sdxl_lda_fp16_fix_weights: Path, - sdxl_lcm_unet_weights: Path, - sdxl_text_encoder_weights: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_lcm_weights_path: Path, + sdxl_text_encoder_weights_path: Path, expected_lcm_base: Image.Image, ) -> None: if test_device.type == "cpu": @@ -111,9 +66,9 @@ def test_lcm_base( # not in the diffusion loop. SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject() - sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) - sdxl.unet.load_from_safetensors(sdxl_lcm_unet_weights) + sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path) + sdxl.unet.load_from_safetensors(sdxl_unet_lcm_weights_path) prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" expected_image = expected_lcm_base @@ -141,10 +96,10 @@ def test_lcm_base( @pytest.mark.parametrize("condition_scale", [1.0, 1.2]) def test_lcm_lora_with_guidance( test_device: torch.device, - sdxl_lda_fp16_fix_weights: Path, - sdxl_unet_weights: Path, - sdxl_text_encoder_weights: Path, - sdxl_lcm_lora_weights: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_weights_path: Path, + sdxl_text_encoder_weights_path: Path, + lora_sdxl_lcm_weights_path: Path, expected_lcm_lora_1_0: Image.Image, expected_lcm_lora_1_2: Image.Image, condition_scale: float, @@ -156,12 +111,12 @@ def test_lcm_lora_with_guidance( solver = LCMSolver(num_inference_steps=4) sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) - sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) - sdxl.unet.load_from_safetensors(sdxl_unet_weights) + sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path) + sdxl.unet.load_from_safetensors(sdxl_unet_weights_path) manager = SDLoraManager(sdxl) - add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights)) + add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path)) prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2 @@ -191,10 +146,10 @@ def test_lcm_lora_with_guidance( @no_grad() def test_lcm_lora_without_guidance( test_device: torch.device, - sdxl_lda_fp16_fix_weights: Path, - sdxl_unet_weights: Path, - sdxl_text_encoder_weights: Path, - sdxl_lcm_lora_weights: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_weights_path: Path, + sdxl_text_encoder_weights_path: Path, + lora_sdxl_lcm_weights_path: Path, expected_lcm_lora_1_0: Image.Image, ) -> None: if test_device.type == "cpu": @@ -205,12 +160,12 @@ def test_lcm_lora_without_guidance( sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl.classifier_free_guidance = False - sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) - sdxl.unet.load_from_safetensors(sdxl_unet_weights) + sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path) + sdxl.unet.load_from_safetensors(sdxl_unet_weights_path) manager = SDLoraManager(sdxl) - add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights)) + add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path)) prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" expected_image = expected_lcm_lora_1_0 diff --git a/tests/e2e/test_lightning.py b/tests/e2e/test_lightning.py index bf8eb8a..83d701e 100644 --- a/tests/e2e/test_lightning.py +++ b/tests/e2e/test_lightning.py @@ -25,60 +25,6 @@ def ensure_gc(): gc.collect() -@pytest.fixture -def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl-lda-fp16-fix.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_unet_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl-unet.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_lightning_4step_unet_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl_lightning_4step_unet.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_lightning_1step_unet_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl_lightning_1step_unet_x0.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_text_encoder_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "DoubleCLIPTextEncoder.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture -def sdxl_lightning_4step_lora_weights(test_weights_path: Path) -> Path: - r = test_weights_path / "sdxl_lightning_4step_lora.safetensors" - if not r.is_file(): - warn(f"could not find weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - @pytest.fixture(scope="module") def ref_path(test_e2e_path: Path) -> Path: return test_e2e_path / "test_lightning_ref" @@ -102,16 +48,16 @@ def expected_lightning_lora_4step(ref_path: Path) -> Image.Image: @no_grad() def test_lightning_base_4step( test_device: torch.device, - sdxl_lda_fp16_fix_weights: Path, - sdxl_lightning_4step_unet_weights: Path, - sdxl_text_encoder_weights: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_lightning_4step_weights_path: Path, + sdxl_text_encoder_weights_path: Path, expected_lightning_base_4step: Image.Image, ) -> None: if test_device.type == "cpu": warn(message="not running on CPU, skipping") pytest.skip() - unet_weights = sdxl_lightning_4step_unet_weights + unet_weights = sdxl_unet_lightning_4step_weights_path expected_image = expected_lightning_base_4step solver = Euler( @@ -125,8 +71,8 @@ def test_lightning_base_4step( sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl.classifier_free_guidance = False - sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) + sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path) sdxl.unet.load_from_safetensors(unet_weights) prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" @@ -153,16 +99,16 @@ def test_lightning_base_4step( @no_grad() def test_lightning_base_1step( test_device: torch.device, - sdxl_lda_fp16_fix_weights: Path, - sdxl_lightning_1step_unet_weights: Path, - sdxl_text_encoder_weights: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_lightning_1step_weights_path: Path, + sdxl_text_encoder_weights_path: Path, expected_lightning_base_1step: Image.Image, ) -> None: if test_device.type == "cpu": warn(message="not running on CPU, skipping") pytest.skip() - unet_weights = sdxl_lightning_1step_unet_weights + unet_weights = sdxl_unet_lightning_1step_weights_path expected_image = expected_lightning_base_1step solver = Euler( @@ -176,8 +122,8 @@ def test_lightning_base_1step( sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl.classifier_free_guidance = False - sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) + sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path) sdxl.unet.load_from_safetensors(unet_weights) prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" @@ -204,10 +150,10 @@ def test_lightning_base_1step( @no_grad() def test_lightning_lora_4step( test_device: torch.device, - sdxl_lda_fp16_fix_weights: Path, - sdxl_unet_weights: Path, - sdxl_text_encoder_weights: Path, - sdxl_lightning_4step_lora_weights: Path, + sdxl_autoencoder_fp16fix_weights_path: Path, + sdxl_unet_weights_path: Path, + sdxl_text_encoder_weights_path: Path, + lora_sdxl_lightning_4step_weights_path: Path, expected_lightning_lora_4step: Image.Image, ) -> None: if test_device.type == "cpu": @@ -227,12 +173,12 @@ def test_lightning_lora_4step( sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver) sdxl.classifier_free_guidance = False - sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) - sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) - sdxl.unet.load_from_safetensors(sdxl_unet_weights) + sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path) + sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path) + sdxl.unet.load_from_safetensors(sdxl_unet_weights_path) manager = SDLoraManager(sdxl) - add_lcm_lora(manager, load_from_safetensors(sdxl_lightning_4step_lora_weights), name="lightning") + add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lightning_4step_weights_path), name="lightning") prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" diff --git a/tests/e2e/test_mvanet.py b/tests/e2e/test_mvanet.py index 71833ba..0022f71 100644 --- a/tests/e2e/test_mvanet.py +++ b/tests/e2e/test_mvanet.py @@ -29,19 +29,10 @@ def expected_cactus_mask(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_cactus_mask.png") -@pytest.fixture(scope="module") -def mvanet_weights(test_weights_path: Path) -> Path: - weights = test_weights_path / "mvanet" / "mvanet.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {test_weights_path}, skipping") - pytest.skip(allow_module_level=True) - return weights - - @pytest.fixture -def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet: +def mvanet_model(mvanet_weights_path: Path, test_device: torch.device) -> MVANet: model = MVANet(device=test_device).eval() # .eval() is important! - model.load_from_safetensors(mvanet_weights) + model.load_from_safetensors(mvanet_weights_path) return model @@ -61,7 +52,7 @@ def test_mvanet( @no_grad() def test_mvanet_to( - mvanet_weights: Path, + mvanet_weights_path: Path, ref_cactus: Image.Image, expected_cactus_mask: Image.Image, test_device: torch.device, @@ -71,7 +62,7 @@ def test_mvanet_to( pytest.skip() model = MVANet(device=torch.device("cpu")).eval() - model.load_from_safetensors(mvanet_weights) + model.load_from_safetensors(mvanet_weights_path) model.to(test_device) in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze() diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py index 6eedf42..c5c9ae6 100644 --- a/tests/e2e/test_preprocessors.py +++ b/tests/e2e/test_preprocessors.py @@ -1,5 +1,4 @@ from pathlib import Path -from warnings import warn import pytest import torch @@ -29,19 +28,13 @@ def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB") -@pytest.fixture(scope="module") -def informative_drawings_weights(test_weights_path: Path) -> Path: - weights = test_weights_path / "informative-drawings.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {test_weights_path}, skipping") - pytest.skip(allow_module_level=True) - return weights - - @pytest.fixture -def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings: +def informative_drawings_model( + controlnet_preprocessor_info_drawings_weights_path: Path, + test_device: torch.device, +) -> InformativeDrawings: model = InformativeDrawings(device=test_device) - model.load_from_safetensors(informative_drawings_weights) + model.load_from_safetensors(controlnet_preprocessor_info_drawings_weights_path) return model diff --git a/tests/e2e/test_solutions.py b/tests/e2e/test_solutions.py index b9a0769..5dc47ca 100644 --- a/tests/e2e/test_solutions.py +++ b/tests/e2e/test_solutions.py @@ -1,5 +1,4 @@ from pathlib import Path -from warnings import warn import pytest import torch @@ -38,24 +37,15 @@ def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png") -@pytest.fixture(scope="module") -def box_segmenter_weights(test_weights_path: Path) -> Path: - weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {test_weights_path}, skipping") - pytest.skip(allow_module_level=True) - return weights - - def test_box_segmenter( - box_segmenter_weights: Path, + box_segmenter_weights_path: Path, ref_shelves: Image.Image, expected_box_segmenter_plant_mask: Image.Image, expected_box_segmenter_spray_mask: Image.Image, expected_box_segmenter_spray_cropped_mask: Image.Image, test_device: torch.device, ): - segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device) + segmenter = BoxSegmenter(weights=box_segmenter_weights_path, device=test_device) plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368)) ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB")) diff --git a/tests/fluxion/test_model_converter.py b/tests/fluxion/test_model_converter.py index 657ce0a..e0d2bc4 100644 --- a/tests/fluxion/test_model_converter.py +++ b/tests/fluxion/test_model_converter.py @@ -4,7 +4,7 @@ import torch from torch import Tensor, nn import refiners.fluxion.layers as fl -from refiners.fluxion.model_converter import ConversionStage, ModelConverter +from refiners.conversion.model_converter import ConversionStage, ModelConverter from refiners.fluxion.utils import manual_seed diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py index debeee5..81a6b3b 100644 --- a/tests/foundationals/clip/test_concepts.py +++ b/tests/foundationals/clip/test_concepts.py @@ -1,5 +1,4 @@ from pathlib import Path -from warnings import warn import pytest import torch @@ -20,17 +19,13 @@ PROMPTS = [ @pytest.fixture(scope="module") def our_encoder_with_new_concepts( - test_weights_path: Path, + sd15_text_encoder_weights_path: Path, test_device: torch.device, cat_embedding_textual_inversion: torch.Tensor, gta5_artwork_embedding_textual_inversion: torch.Tensor, ) -> CLIPTextEncoderL: - weights = test_weights_path / "CLIPTextEncoderL.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {weights}, skipping") - pytest.skip(allow_module_level=True) encoder = CLIPTextEncoderL(device=test_device) - tensors = load_from_safetensors(weights) + tensors = load_from_safetensors(sd15_text_encoder_weights_path) encoder.load_state_dict(tensors) concept_extender = ConceptExtender(encoder) concept_extender.add_concept("", cat_embedding_textual_inversion) @@ -41,24 +36,21 @@ def our_encoder_with_new_concepts( @pytest.fixture(scope="module") def ref_sd15_with_new_concepts( - runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device + sd15_diffusers_runwayml_path: str, + test_textual_inversion_path: Path, + test_device: torch.device, + use_local_weights: bool, ) -> StableDiffusionPipeline: - pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path).to(test_device) # type: ignore + pipe = StableDiffusionPipeline.from_pretrained( # type: ignore + sd15_diffusers_runwayml_path, + local_files_only=use_local_weights, + ).to(test_device) # type: ignore assert isinstance(pipe, StableDiffusionPipeline) pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore return pipe -@pytest.fixture(scope="module") -def runwayml_weights_path(test_weights_path: Path): - r = test_weights_path / "runwayml" / "stable-diffusion-v1-5" - if not r.is_dir(): - warn(f"could not find RunwayML weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - @pytest.fixture(scope="module") def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer: return ref_sd15_with_new_concepts.tokenizer # type: ignore diff --git a/tests/foundationals/clip/test_image_encoder.py b/tests/foundationals/clip/test_image_encoder.py index 69988c2..e532d3a 100644 --- a/tests/foundationals/clip/test_image_encoder.py +++ b/tests/foundationals/clip/test_image_encoder.py @@ -1,5 +1,4 @@ from pathlib import Path -from warnings import warn import pytest import torch @@ -11,39 +10,28 @@ from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH @pytest.fixture(scope="module") def our_encoder( - test_weights_path: Path, + clip_image_encoder_huge_weights_path: Path, test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype, ) -> CLIPImageEncoderH: - weights = test_weights_path / "CLIPImageEncoderH.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {weights}, skipping") - pytest.skip(allow_module_level=True) encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16) - tensors = load_from_safetensors(weights) + tensors = load_from_safetensors(clip_image_encoder_huge_weights_path) encoder.load_state_dict(tensors) return encoder -@pytest.fixture(scope="module") -def stabilityai_unclip_weights_path(test_weights_path: Path): - r = test_weights_path / "stabilityai" / "stable-diffusion-2-1-unclip" - if not r.is_dir(): - warn(f"could not find Stability AI weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - @pytest.fixture(scope="module") def ref_encoder( - stabilityai_unclip_weights_path: Path, + unclip21_transformers_stabilityai_path: str, test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype, + use_local_weights: bool, ) -> CLIPVisionModelWithProjection: return CLIPVisionModelWithProjection.from_pretrained( # type: ignore - stabilityai_unclip_weights_path, + unclip21_transformers_stabilityai_path, + local_files_only=use_local_weights, subfolder="image_encoder", - ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) + ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) # type: ignore @no_grad() diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py index 28eeada..67e841c 100644 --- a/tests/foundationals/clip/test_text_encoder.py +++ b/tests/foundationals/clip/test_text_encoder.py @@ -1,5 +1,4 @@ from pathlib import Path -from warnings import warn import pytest import torch @@ -31,42 +30,39 @@ PROMPTS = [ @pytest.fixture(scope="module") def our_encoder( - test_weights_path: Path, + sd15_text_encoder_weights_path: Path, test_device: torch.device, test_dtype_fp32_fp16: torch.dtype, ) -> CLIPTextEncoderL: - weights = test_weights_path / "CLIPTextEncoderL.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {weights}, skipping") - pytest.skip(allow_module_level=True) - tensors = load_from_safetensors(weights) encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16) + tensors = load_from_safetensors(sd15_text_encoder_weights_path) + encoder.load_state_dict(tensors) return encoder @pytest.fixture(scope="module") -def runwayml_weights_path(test_weights_path: Path): - r = test_weights_path / "runwayml" / "stable-diffusion-v1-5" - if not r.is_dir(): - warn(f"could not find RunwayML weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture(scope="module") -def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer: - return transformers.CLIPTokenizer.from_pretrained(runwayml_weights_path, subfolder="tokenizer") # type: ignore +def ref_tokenizer( + sd15_diffusers_runwayml_path: str, + use_local_weights: bool, +) -> transformers.CLIPTokenizer: + return transformers.CLIPTokenizer.from_pretrained( # type: ignore + sd15_diffusers_runwayml_path, + local_files_only=use_local_weights, + subfolder="tokenizer", + ) @pytest.fixture(scope="module") def ref_encoder( - runwayml_weights_path: Path, + sd15_diffusers_runwayml_path: str, test_device: torch.device, test_dtype_fp32_fp16: torch.dtype, + use_local_weights: bool, ) -> transformers.CLIPTextModel: return transformers.CLIPTextModel.from_pretrained( # type: ignore - runwayml_weights_path, + sd15_diffusers_runwayml_path, + local_files_only=use_local_weights, subfolder="text_encoder", ).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py index 1e73b6d..4d48bcd 100644 --- a/tests/foundationals/dinov2/test_dinov2.py +++ b/tests/foundationals/dinov2/test_dinov2.py @@ -4,6 +4,7 @@ from warnings import warn import pytest import torch +from huggingface_hub import hf_hub_download # type: ignore from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad from refiners.foundationals.dinov2.dinov2 import ( @@ -18,7 +19,7 @@ from refiners.foundationals.dinov2.dinov2 import ( ) from refiners.foundationals.dinov2.vit import ViT -FLAVORS_MAP = { +FLAVORS_MAP_REFINERS = { "dinov2_vits14": DINOv2_small, "dinov2_vits14_reg": DINOv2_small_reg, "dinov2_vitb14": DINOv2_base, @@ -28,6 +29,27 @@ FLAVORS_MAP = { "dinov2_vitg14": DINOv2_giant, "dinov2_vitg14_reg": DINOv2_giant_reg, } +FLAVORS_MAP_HUB = { + "dinov2_vits14": "refiners/dinov2.small.patch_14", + "dinov2_vits14_reg": "refiners/dinov2.small.patch_14.reg_4", + "dinov2_vitb14": "refiners/dinov2.base.patch_14", + "dinov2_vitb14_reg": "refiners/dinov2.base.patch_14.reg_4", + "dinov2_vitl14": "refiners/dinov2.large.patch_14", + "dinov2_vitl14_reg": "refiners/dinov2.large.patch_14.reg_4", + "dinov2_vitg14": "refiners/dinov2.giant.patch_14", + "dinov2_vitg14_reg": "refiners/dinov2.giant.patch_14.reg_4", +} + + +@pytest.fixture(scope="module", params=["float16", "bfloat16"]) +def dtype(request: pytest.FixtureRequest) -> torch.dtype: + match request.param: + case "float16": + return torch.float16 + case "bfloat16": + return torch.bfloat16 + case _ as dtype: + raise ValueError(f"unsupported dtype: {dtype}") @pytest.fixture(scope="module", params=[224, 518]) @@ -35,7 +57,7 @@ def resolution(request: pytest.FixtureRequest) -> int: return request.param -@pytest.fixture(scope="module", params=FLAVORS_MAP.keys()) +@pytest.fixture(scope="module", params=FLAVORS_MAP_REFINERS.keys()) def flavor(request: pytest.FixtureRequest) -> str: return request.param @@ -53,7 +75,14 @@ def dinov2_repo_path(test_repos_path: Path) -> Path: def ref_model( flavor: str, dinov2_repo_path: Path, - test_weights_path: Path, + dinov2_small_unconverted_weights_path: Path, + dinov2_small_reg4_unconverted_weights_path: Path, + dinov2_base_unconverted_weights_path: Path, + dinov2_base_reg4_unconverted_weights_path: Path, + dinov2_large_unconverted_weights_path: Path, + dinov2_large_reg4_unconverted_weights_path: Path, + dinov2_giant_unconverted_weights_path: Path, + dinov2_giant_reg4_unconverted_weights_path: Path, test_device: torch.device, ) -> torch.nn.Module: kwargs: dict[str, Any] = {} @@ -69,34 +98,51 @@ def ref_model( ) model = model.to(device=test_device) - flavor = flavor.replace("_reg", "_reg4") - weights = test_weights_path / f"{flavor}_pretrain.pth" - if not weights.is_file(): - warn(f"could not find weights at {weights}, skipping") - pytest.skip(allow_module_level=True) - model.load_state_dict(load_tensors(weights, device=test_device)) + weight_map = { + "dinov2_vits14": dinov2_small_unconverted_weights_path, + "dinov2_vits14_reg": dinov2_small_reg4_unconverted_weights_path, + "dinov2_vitb14": dinov2_base_unconverted_weights_path, + "dinov2_vitb14_reg": dinov2_base_reg4_unconverted_weights_path, + "dinov2_vitl14": dinov2_large_unconverted_weights_path, + "dinov2_vitl14_reg": dinov2_large_reg4_unconverted_weights_path, + "dinov2_vitg14": dinov2_giant_unconverted_weights_path, + "dinov2_vitg14_reg": dinov2_giant_reg4_unconverted_weights_path, + } + weights_path = weight_map[flavor] + model.load_state_dict(load_tensors(weights_path, device=test_device)) assert isinstance(model, torch.nn.Module) return model @pytest.fixture(scope="module") def our_model( - test_weights_path: Path, flavor: str, + dinov2_small_weights_path: Path, + dinov2_small_reg4_weights_path: Path, + dinov2_base_weights_path: Path, + dinov2_base_reg4_weights_path: Path, + dinov2_large_weights_path: Path, + dinov2_large_reg4_weights_path: Path, + dinov2_giant_weights_path: Path, + dinov2_giant_reg4_weights_path: Path, test_device: torch.device, ) -> ViT: - model = FLAVORS_MAP[flavor](device=test_device) + weight_map = { + "dinov2_vits14": dinov2_small_weights_path, + "dinov2_vits14_reg": dinov2_small_reg4_weights_path, + "dinov2_vitb14": dinov2_base_weights_path, + "dinov2_vitb14_reg": dinov2_base_reg4_weights_path, + "dinov2_vitl14": dinov2_large_weights_path, + "dinov2_vitl14_reg": dinov2_large_reg4_weights_path, + "dinov2_vitg14": dinov2_giant_weights_path, + "dinov2_vitg14_reg": dinov2_giant_reg4_weights_path, + } + weights_path = weight_map[flavor] - flavor = flavor.replace("_reg", "_reg4") - weights = test_weights_path / f"{flavor}_pretrain.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {weights}, skipping") - pytest.skip(allow_module_level=True) - - tensors = load_from_safetensors(weights) + model = FLAVORS_MAP_REFINERS[flavor](device=test_device) + tensors = load_from_safetensors(weights_path) model.load_state_dict(tensors) - return model diff --git a/tests/foundationals/latent_diffusion/conftest.py b/tests/foundationals/latent_diffusion/conftest.py new file mode 100644 index 0000000..8e05561 --- /dev/null +++ b/tests/foundationals/latent_diffusion/conftest.py @@ -0,0 +1,139 @@ +from pathlib import Path + +import pytest +import torch +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline + +from refiners.fluxion.utils import load_from_safetensors +from refiners.foundationals.latent_diffusion import ( + CLIPTextEncoderL, + DoubleTextEncoder, + SD1Autoencoder, + SD1UNet, + SDXLAutoencoder, + SDXLUNet, + StableDiffusion_1, + StableDiffusion_XL, +) + + +@pytest.fixture(scope="module") +def refiners_sd15_autoencoder(sd15_autoencoder_weights_path: Path) -> SD1Autoencoder: + autoencoder = SD1Autoencoder() + tensors = load_from_safetensors(sd15_autoencoder_weights_path) + autoencoder.load_state_dict(tensors) + return autoencoder + + +@pytest.fixture(scope="module") +def refiners_sd15_unet(sd15_unet_weights_path: Path) -> SD1UNet: + unet = SD1UNet(in_channels=4) + tensors = load_from_safetensors(sd15_unet_weights_path) + unet.load_state_dict(tensors) + return unet + + +@pytest.fixture(scope="module") +def refiners_sd15_text_encoder(sd15_text_encoder_weights_path: Path) -> CLIPTextEncoderL: + text_encoder = CLIPTextEncoderL() + tensors = load_from_safetensors(sd15_text_encoder_weights_path) + text_encoder.load_state_dict(tensors) + return text_encoder + + +@pytest.fixture(scope="module") +def refiners_sd15( + refiners_sd15_autoencoder: SD1Autoencoder, + refiners_sd15_unet: SD1UNet, + refiners_sd15_text_encoder: CLIPTextEncoderL, +) -> StableDiffusion_1: + return StableDiffusion_1( + lda=refiners_sd15_autoencoder, + unet=refiners_sd15_unet, + clip_text_encoder=refiners_sd15_text_encoder, + ) + + +@pytest.fixture(scope="module") +def refiners_sdxl_autoencoder(sdxl_autoencoder_weights_path: Path) -> SDXLAutoencoder: + autoencoder = SDXLAutoencoder() + tensors = load_from_safetensors(sdxl_autoencoder_weights_path) + autoencoder.load_state_dict(tensors) + return autoencoder + + +@pytest.fixture(scope="module") +def refiners_sdxl_unet(sdxl_unet_weights_path: Path) -> SDXLUNet: + unet = SDXLUNet(in_channels=4) + tensors = load_from_safetensors(sdxl_unet_weights_path) + unet.load_state_dict(tensors) + return unet + + +@pytest.fixture(scope="module") +def refiners_sdxl_text_encoder(sdxl_text_encoder_weights_path: Path) -> DoubleTextEncoder: + text_encoder = DoubleTextEncoder() + tensors = load_from_safetensors(sdxl_text_encoder_weights_path) + text_encoder.load_state_dict(tensors) + return text_encoder + + +@pytest.fixture(scope="module") +def refiners_sdxl( + refiners_sdxl_autoencoder: SDXLAutoencoder, + refiners_sdxl_unet: SDXLUNet, + refiners_sd15_text_encoder: DoubleTextEncoder, +) -> StableDiffusion_XL: + return StableDiffusion_XL( + lda=refiners_sdxl_autoencoder, + unet=refiners_sdxl_unet, + clip_text_encoder=refiners_sd15_text_encoder, + ) + + +@pytest.fixture(scope="module", params=["SD1.5", "SDXL"]) +def refiners_autoencoder( + request: pytest.FixtureRequest, + refiners_sd15_autoencoder: SD1Autoencoder, + refiners_sdxl_autoencoder: SDXLAutoencoder, + test_dtype_fp32_bf16_fp16: torch.dtype, +) -> SD1Autoencoder | SDXLAutoencoder: + model_version = request.param + match (model_version, test_dtype_fp32_bf16_fp16): + case ("SD1.5", _): + return refiners_sd15_autoencoder + case ("SDXL", torch.float16): + return refiners_sdxl_autoencoder + case ("SDXL", _): + return refiners_sdxl_autoencoder + case _: + raise ValueError(f"Unknown model version: {model_version}") + + +@pytest.fixture(scope="module") +def diffusers_sd15_pipeline( + sd15_diffusers_runwayml_path: str, + use_local_weights: bool, +) -> StableDiffusionPipeline: + return StableDiffusionPipeline.from_pretrained( # type: ignore + sd15_diffusers_runwayml_path, + local_files_only=use_local_weights, + ) + + +@pytest.fixture(scope="module") +def diffusers_sdxl_pipeline( + sdxl_diffusers_stabilityai_path: str, + use_local_weights: bool, +) -> StableDiffusionXLPipeline: + return StableDiffusionXLPipeline.from_pretrained( # type: ignore + sdxl_diffusers_stabilityai_path, + local_files_only=use_local_weights, + ) + + +@pytest.fixture(scope="module") +def diffusers_sdxl_unet(diffusers_sdxl_pipeline: StableDiffusionXLPipeline) -> UNet2DConditionModel: + return diffusers_sdxl_pipeline.unet # type: ignore diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py deleted file mode 100644 index eedd251..0000000 --- a/tests/foundationals/latent_diffusion/test_auto_encoder.py +++ /dev/null @@ -1,134 +0,0 @@ -from pathlib import Path -from warnings import warn - -import pytest -import torch -from PIL import Image -from tests.utils import ensure_similar_images - -from refiners.fluxion.utils import no_grad -from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder - - -@pytest.fixture(scope="module") -def ref_path() -> Path: - return Path(__file__).parent / "test_auto_encoder_ref" - - -@pytest.fixture(scope="module", params=["SD1.5", "SDXL"]) -def lda( - request: pytest.FixtureRequest, - test_weights_path: Path, - test_dtype_fp32_bf16_fp16: torch.dtype, - test_device: torch.device, -) -> LatentDiffusionAutoencoder: - model_version = request.param - match (model_version, test_dtype_fp32_bf16_fp16): - case ("SD1.5", _): - weight_path = test_weights_path / "lda.safetensors" - if not weight_path.is_file(): - warn(f"could not find weights at {weight_path}, skipping") - pytest.skip(allow_module_level=True) - model = SD1Autoencoder().load_from_safetensors(weight_path) - case ("SDXL", torch.float16): - weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors" - if not weight_path.is_file(): - warn(f"could not find weights at {weight_path}, skipping") - pytest.skip(allow_module_level=True) - model = SDXLAutoencoder().load_from_safetensors(weight_path) - case ("SDXL", _): - weight_path = test_weights_path / "sdxl-lda.safetensors" - if not weight_path.is_file(): - warn(f"could not find weights at {weight_path}, skipping") - pytest.skip(allow_module_level=True) - model = SDXLAutoencoder().load_from_safetensors(weight_path) - case _: - raise ValueError(f"Unknown model version: {model_version}") - model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) - return model - - -@pytest.fixture(scope="module") -def sample_image(ref_path: Path) -> Image.Image: - test_image = ref_path / "macaw.png" - if not test_image.is_file(): - warn(f"could not reference image at {test_image}, skipping") - pytest.skip(allow_module_level=True) - img = Image.open(test_image) # type: ignore - assert img.size == (512, 512) - return img - - -@no_grad() -def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image): - encoded = lda.image_to_latents(sample_image) - decoded = lda.latents_to_image(encoded) - - assert decoded.mode == "RGB" # type: ignore - - # Ensure no saturation. The green channel (band = 1) must not max out. - assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore - - ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9) - - -@no_grad() -def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image): - encoded = lda.images_to_latents([sample_image, sample_image]) - images = lda.latents_to_images(encoded) - assert isinstance(images, list) - assert len(images) == 2 - ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9) - - -@no_grad() -def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image): - sample_image = sample_image.resize((2048, 2048)) # type: ignore - - with lda.tiled_inference(sample_image, tile_size=(512, 512)): - encoded = lda.tiled_image_to_latents(sample_image) - result = lda.tiled_latents_to_image(encoded) - - ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985) - - -@no_grad() -def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image): - sample_image = sample_image.resize((2048, 2048)) # type: ignore - - with lda.tiled_inference(sample_image, tile_size=(512, 1024)): - encoded = lda.tiled_image_to_latents(sample_image) - result = lda.tiled_latents_to_image(encoded) - - ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985) - - -@no_grad() -def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image): - sample_image = sample_image.resize((1024, 1024)) # type: ignore - - with lda.tiled_inference(sample_image, tile_size=(2048, 2048)): - encoded = lda.tiled_image_to_latents(sample_image) - result = lda.tiled_latents_to_image(encoded) - - ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975) - - -@no_grad() -def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image): - sample_image = sample_image.crop((0, 0, 300, 500)) - sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore - - with lda.tiled_inference(sample_image, tile_size=(512, 512)): - encoded = lda.tiled_image_to_latents(sample_image) - result = lda.tiled_latents_to_image(encoded) - - ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985) - - -def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None: - with pytest.raises(ValueError): - lda.tiled_image_to_latents(sample_image) - - with pytest.raises(ValueError): - lda.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=lda.device)) diff --git a/tests/foundationals/latent_diffusion/test_autoencoders.py b/tests/foundationals/latent_diffusion/test_autoencoders.py new file mode 100644 index 0000000..d21ac3e --- /dev/null +++ b/tests/foundationals/latent_diffusion/test_autoencoders.py @@ -0,0 +1,104 @@ +from pathlib import Path +from warnings import warn + +import pytest +import torch +from PIL import Image +from tests.utils import ensure_similar_images + +from refiners.fluxion.utils import no_grad +from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder + + +@pytest.fixture(scope="module") +def sample_image() -> Image.Image: + test_image = Path(__file__).parent / "test_auto_encoder_ref" / "macaw.png" + if not test_image.is_file(): + warn(f"could not reference image at {test_image}, skipping") + pytest.skip(allow_module_level=True) + img = Image.open(test_image) # type: ignore + assert img.size == (512, 512) + return img + + +@pytest.fixture(scope="module") +def autoencoder( + refiners_autoencoder: LatentDiffusionAutoencoder, + test_device: torch.device, +) -> LatentDiffusionAutoencoder: + return refiners_autoencoder.to(test_device) + + +@no_grad() +def test_encode_decode_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): + encoded = autoencoder.image_to_latents(sample_image) + decoded = autoencoder.latents_to_image(encoded) + + assert decoded.mode == "RGB" # type: ignore + + # Ensure no saturation. The green channel (band = 1) must not max out. + assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore + + ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9) + + +@no_grad() +def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): + encoded = autoencoder.images_to_latents([sample_image, sample_image]) + images = autoencoder.latents_to_images(encoded) + assert isinstance(images, list) + assert len(images) == 2 + ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9) + + +@no_grad() +def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): + sample_image = sample_image.resize((2048, 2048)) # type: ignore + + with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)): + encoded = autoencoder.tiled_image_to_latents(sample_image) + result = autoencoder.tiled_latents_to_image(encoded) + + ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985) + + +@no_grad() +def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): + sample_image = sample_image.resize((2048, 2048)) # type: ignore + + with autoencoder.tiled_inference(sample_image, tile_size=(512, 1024)): + encoded = autoencoder.tiled_image_to_latents(sample_image) + result = autoencoder.tiled_latents_to_image(encoded) + + ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985) + + +@no_grad() +def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): + sample_image = sample_image.resize((1024, 1024)) # type: ignore + + with autoencoder.tiled_inference(sample_image, tile_size=(2048, 2048)): + encoded = autoencoder.tiled_image_to_latents(sample_image) + result = autoencoder.tiled_latents_to_image(encoded) + + ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975) + + +@no_grad() +def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): + sample_image = sample_image.crop((0, 0, 300, 500)) + sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore + + with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)): + encoded = autoencoder.tiled_image_to_latents(sample_image) + result = autoencoder.tiled_latents_to_image(encoded) + + ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985) + + +def test_value_error_tile_encode_no_context(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None: + with pytest.raises(ValueError): + autoencoder.tiled_image_to_latents(sample_image) + + with pytest.raises(ValueError): + autoencoder.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=autoencoder.device)) diff --git a/tests/foundationals/latent_diffusion/test_model.py b/tests/foundationals/latent_diffusion/test_model.py deleted file mode 100644 index dafd38f..0000000 --- a/tests/foundationals/latent_diffusion/test_model.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from PIL import Image - -from refiners.fluxion.utils import manual_seed, no_grad -from refiners.foundationals.latent_diffusion import StableDiffusion_1_Inpainting -from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel - - -@no_grad() -def test_sample_noise(): - manual_seed(2) - latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64)) - manual_seed(2) - latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0) - - assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0) - - -@no_grad() -def test_sd1_inpainting(test_device: torch.device) -> None: - sd = StableDiffusion_1_Inpainting(device=test_device) - - latent_noise = torch.randn(1, 4, 64, 64, device=test_device) - target_image = Image.new("RGB", (512, 512)) - mask = Image.new("L", (512, 512)) - - sd.set_inpainting_conditions(target_image=target_image, mask=mask) - text_embedding = sd.compute_clip_text_embedding("") - output = sd(latent_noise, step=0, clip_text_embedding=text_embedding) - - assert output.shape == (1, 4, 64, 64) diff --git a/tests/foundationals/latent_diffusion/test_sd15_text_encoder.py b/tests/foundationals/latent_diffusion/test_sd15_text_encoder.py new file mode 100644 index 0000000..b4b7845 --- /dev/null +++ b/tests/foundationals/latent_diffusion/test_sd15_text_encoder.py @@ -0,0 +1,65 @@ +import torch +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline +from torch import Tensor + +from refiners.fluxion.utils import manual_seed, no_grad +from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL + + +@no_grad() +def test_text_encoder( + diffusers_sd15_pipeline: StableDiffusionPipeline, + refiners_sd15_text_encoder: CLIPTextEncoderL, +) -> None: + """Compare our refiners implementation with the diffusers implementation.""" + manual_seed(seed=0) # unnecessary, but just in case + prompt = "A photo of a pizza." + negative_prompt = "" + atol = 1e-2 # FIXME: very high tolerance, figure out why + + ( # encode text prompts using diffusers pipeline + diffusers_embeds, # type: ignore + diffusers_negative_embeds, # type: ignore + ) = diffusers_sd15_pipeline.encode_prompt( # type: ignore + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + device=diffusers_sd15_pipeline.device, + ) + assert isinstance(diffusers_embeds, Tensor) + assert isinstance(diffusers_negative_embeds, Tensor) + + # encode text prompts using refiners model + refiners_embeds = refiners_sd15_text_encoder(prompt) + refiners_negative_embeds = refiners_sd15_text_encoder("") + + # check that the shapes are the same + assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 768) + assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 768) + + # check that the values are close + assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol) + assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol) + + +@no_grad() +def test_text_encoder_batched(refiners_sd15_text_encoder: CLIPTextEncoderL) -> None: + """Check that encoding two prompts works as expected whether batched or not.""" + manual_seed(seed=0) # unnecessary, but just in case + prompt1 = "A photo of a pizza." + prompt2 = "A giant duck." + atol = 1e-6 + + # encode the two prompts at once + embeds_batched = refiners_sd15_text_encoder([prompt1, prompt2]) + assert embeds_batched.shape == (2, 77, 768) + + # encode the prompts one by one + embeds_1 = refiners_sd15_text_encoder(prompt1) + embeds_2 = refiners_sd15_text_encoder(prompt2) + assert embeds_1.shape == embeds_2.shape == (1, 77, 768) + + # check that the values are close + assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol) + assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol) diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py deleted file mode 100644 index d738ff2..0000000 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ /dev/null @@ -1,121 +0,0 @@ -from pathlib import Path -from typing import Any, Protocol, cast -from warnings import warn - -import pytest -import torch -from torch import Tensor - -import refiners.fluxion.layers as fl -from refiners.fluxion.utils import manual_seed, no_grad -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder - - -class DiffusersSDXL(Protocol): - unet: fl.Module - text_encoder: fl.Module - text_encoder_2: fl.Module - tokenizer: fl.Module - tokenizer_2: fl.Module - vae: fl.Module - - def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ... - - def encode_prompt( - self, - prompt: str, - prompt_2: str | None = None, - negative_prompt: str | None = None, - negative_prompt_2: str | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... - - -@pytest.fixture(scope="module") -def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path: - r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0" - if not r.is_dir(): - warn(message=f"could not find Stability SDXL base weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture(scope="module") -def double_text_encoder_weights(test_weights_path: Path) -> Path: - text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors" - if not text_encoder_weights.is_file(): - warn(f"could not find weights at {text_encoder_weights}, skipping") - pytest.skip(allow_module_level=True) - return text_encoder_weights - - -@pytest.fixture(scope="module") -def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any: - from diffusers import DiffusionPipeline # type: ignore - - return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore - - -@pytest.fixture(scope="module") -def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder: - double_text_encoder = DoubleTextEncoder() - double_text_encoder.load_from_safetensors(double_text_encoder_weights) - - return double_text_encoder - - -@no_grad() -def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None: - manual_seed(seed=0) - prompt = "A photo of a pizza." - - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="") - - double_embedding, pooled_embedding = double_text_encoder(prompt) - - assert double_embedding.shape == torch.Size([1, 77, 2048]) - assert pooled_embedding.shape == torch.Size([1, 1280]) - - embedding_1, embedding_2 = cast( - tuple[Tensor, Tensor], - prompt_embeds.split(split_size=[768, 1280], dim=-1), # type: ignore - ) - - rembedding_1, rembedding_2 = cast( - tuple[Tensor, Tensor], - double_embedding.split(split_size=[768, 1280], dim=-1), # type: ignore - ) - - assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3) - assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3) - assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3) - - negative_double_embedding, negative_pooled_embedding = double_text_encoder("") - - assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3) - assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3) - - -@no_grad() -def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None: - manual_seed(seed=0) - prompt1 = "A photo of a pizza." - prompt2 = "A giant duck." - - double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2]) - - assert double_embedding_b2.shape == torch.Size([2, 77, 2048]) - assert pooled_embedding_b2.shape == torch.Size([2, 1280]) - - double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1) - double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2) - - assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3) - assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3) - - assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3) - assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3) diff --git a/tests/foundationals/latent_diffusion/test_sdxl_text_encoders.py b/tests/foundationals/latent_diffusion/test_sdxl_text_encoders.py new file mode 100644 index 0000000..c842427 --- /dev/null +++ b/tests/foundationals/latent_diffusion/test_sdxl_text_encoders.py @@ -0,0 +1,70 @@ +import torch +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline +from torch import Tensor + +from refiners.fluxion.utils import manual_seed, no_grad +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder + + +@no_grad() +def test_double_text_encoder( + diffusers_sdxl_pipeline: StableDiffusionXLPipeline, + refiners_sdxl_text_encoder: DoubleTextEncoder, +) -> None: + """Compare our refiners implementation with the diffusers implementation.""" + manual_seed(seed=0) # unnecessary, but just in case + prompt = "A photo of a pizza." + negative_prompt = "" + atol = 1e-6 + + ( # encode text prompts using diffusers pipeline + diffusers_embeds, + diffusers_negative_embeds, # type: ignore + diffusers_pooled_embeds, # type: ignore + diffusers_negative_pooled_embeds, # type: ignore + ) = diffusers_sdxl_pipeline.encode_prompt(prompt=prompt, negative_prompt=negative_prompt) + assert diffusers_negative_embeds is not None + assert isinstance(diffusers_pooled_embeds, Tensor) + assert isinstance(diffusers_negative_pooled_embeds, Tensor) + + # encode text prompts using refiners model + refiners_embeds, refiners_pooled_embeds = refiners_sdxl_text_encoder(prompt) + refiners_negative_embeds, refiners_negative_pooled_embeds = refiners_sdxl_text_encoder("") + + # check that the shapes are the same + assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 2048) + assert diffusers_pooled_embeds.shape == refiners_pooled_embeds.shape == (1, 1280) + assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 2048) + assert diffusers_negative_pooled_embeds.shape == refiners_negative_pooled_embeds.shape == (1, 1280) + + # check that the values are close + assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol) + assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol) + assert torch.allclose(input=refiners_negative_pooled_embeds, other=diffusers_negative_pooled_embeds, atol=atol) + assert torch.allclose(input=refiners_pooled_embeds, other=diffusers_pooled_embeds, atol=atol) + + +@no_grad() +def test_double_text_encoder_batched(refiners_sdxl_text_encoder: DoubleTextEncoder) -> None: + """Check that encoding two prompts works as expected whether batched or not.""" + manual_seed(seed=0) # unnecessary, but just in case + prompt1 = "A photo of a pizza." + prompt2 = "A giant duck." + atol = 1e-6 + + # encode the two prompts at once + embeds_batched, pooled_embeds_batched = refiners_sdxl_text_encoder([prompt1, prompt2]) + assert embeds_batched.shape == (2, 77, 2048) + assert pooled_embeds_batched.shape == (2, 1280) + + # encode the prompts one by one + embeds_1, pooled_embeds_1 = refiners_sdxl_text_encoder(prompt1) + embeds_2, pooled_embeds_2 = refiners_sdxl_text_encoder(prompt2) + assert embeds_1.shape == embeds_2.shape == (1, 77, 2048) + assert pooled_embeds_1.shape == pooled_embeds_2.shape == (1, 1280) + + # check that the values are close + assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol) + assert torch.allclose(input=pooled_embeds_1, other=pooled_embeds_batched[0].unsqueeze(0), atol=atol) + assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol) + assert torch.allclose(input=pooled_embeds_2, other=pooled_embeds_batched[1].unsqueeze(0), atol=atol) diff --git a/tests/foundationals/latent_diffusion/test_sdxl_unet.py b/tests/foundationals/latent_diffusion/test_sdxl_unet.py index c3d0f10..b4855b2 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_unet.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_unet.py @@ -1,36 +1,13 @@ -from pathlib import Path from typing import Any -from warnings import warn import pytest import torch -from refiners.fluxion.model_converter import ConversionStage, ModelConverter +from refiners.conversion.model_converter import ConversionStage, ModelConverter from refiners.fluxion.utils import manual_seed, no_grad from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet -@pytest.fixture(scope="module") -def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path: - r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0" - if not r.is_dir(): - warn(f"could not find Stability SDXL base weights at {r}, skipping") - pytest.skip(allow_module_level=True) - return r - - -@pytest.fixture(scope="module") -def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any: - from diffusers import DiffusionPipeline # type: ignore - - return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore - - -@pytest.fixture(scope="module") -def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any: - return diffusers_sdxl.unet - - @pytest.fixture(scope="module") def refiners_sdxl_unet() -> SDXLUNet: unet = SDXLUNet(in_channels=4) @@ -38,7 +15,10 @@ def refiners_sdxl_unet() -> SDXLUNet: @no_grad() -def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None: +def test_sdxl_unet( + diffusers_sdxl_unet: Any, + refiners_sdxl_unet: SDXLUNet, +) -> None: source = diffusers_sdxl_unet target = refiners_sdxl_unet diff --git a/tests/foundationals/segment_anything/conftest.py b/tests/foundationals/segment_anything/conftest.py index 66d19a4..635ff00 100644 --- a/tests/foundationals/segment_anything/conftest.py +++ b/tests/foundationals/segment_anything/conftest.py @@ -1,8 +1,7 @@ import gc from pathlib import Path -from warnings import warn -from pytest import fixture, skip +from pytest import fixture @fixture(autouse=True) @@ -15,12 +14,3 @@ def ensure_gc(): @fixture(scope="package") def ref_path(test_sam_path: Path) -> Path: return test_sam_path / "test_sam_ref" - - -@fixture(scope="package") -def sam_h_weights(test_weights_path: Path) -> Path: - sam_h_weights = test_weights_path / "segment-anything-h.safetensors" - if not sam_h_weights.is_file(): - warn(f"could not find weights at {sam_h_weights}, skipping") - skip(allow_module_level=True) - return sam_h_weights diff --git a/tests/foundationals/segment_anything/test_hq_sam.py b/tests/foundationals/segment_anything/test_hq_sam.py index cee6fda..439bd26 100644 --- a/tests/foundationals/segment_anything/test_hq_sam.py +++ b/tests/foundationals/segment_anything/test_hq_sam.py @@ -1,6 +1,5 @@ from pathlib import Path from typing import cast -from warnings import warn import numpy as np import pytest @@ -36,37 +35,17 @@ def tennis(ref_path: Path) -> Image.Image: return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore -@pytest.fixture(scope="module") -def hq_adapter_weights(test_weights_path: Path) -> Path: - """Path to the HQ adapter weights in Refiners format""" - refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors" - if not refiners_hq_adapter_sam_weights.is_file(): - warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping") - pytest.skip(allow_module_level=True) - return refiners_hq_adapter_sam_weights - - @pytest.fixture -def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: +def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH: # HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False. sam_h = SegmentAnythingH(multimask_output=False, device=test_device) - sam_h.load_from_safetensors(tensors_path=sam_h_weights) + sam_h.load_from_safetensors(tensors_path=sam_h_weights_path) return sam_h @pytest.fixture(scope="module") -def reference_hq_adapter_weights(test_weights_path: Path) -> Path: - """Path to the HQ adapter weights in default format""" - reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth" - if not reference_hq_adapter_sam_weights.is_file(): - warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping") - pytest.skip(allow_module_level=True) - return reference_hq_adapter_sam_weights - - -@pytest.fixture(scope="module") -def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM: - sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights)) +def reference_sam_h(sam_h_hq_adapter_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM: + sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=sam_h_hq_adapter_unconverted_weights_path)) return sam_h.to(device=test_device) @@ -142,11 +121,11 @@ def test_mask_decoder_tokens_extender() -> None: @no_grad() def test_early_vit_embedding( sam_h: SegmentAnythingH, - hq_adapter_weights: Path, + sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM, tennis: Image.Image, ) -> None: - HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore @@ -159,8 +138,8 @@ def test_early_vit_embedding( assert torch.equal(early_vit_embedding, early_vit_embedding_refiners) -def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: - HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() +def test_tokens(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender) @@ -175,8 +154,10 @@ def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam @no_grad() -def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: - HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() +def test_compress_vit_feat( + sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM +) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype) @@ -189,8 +170,10 @@ def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re @no_grad() -def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: - HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() +def test_embedding_encoder( + sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM +) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype) @@ -203,8 +186,10 @@ def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re @no_grad() -def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: - HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() +def test_hq_token_mlp( + sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM +) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype) @@ -217,13 +202,13 @@ def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, referen @pytest.mark.parametrize("hq_mask_only", [True, False]) def test_predictor( sam_h: SegmentAnythingH, - hq_adapter_weights: Path, + sam_h_hq_adapter_weights_path: Path, hq_mask_only: bool, reference_sam_h_predictor: FacebookSAMPredictorHQ, tennis: Image.Image, one_prompt: SAMPrompt, ) -> None: - adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() adapter.hq_mask_only = hq_mask_only assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only @@ -268,13 +253,13 @@ def test_predictor( @pytest.mark.parametrize("hq_mask_only", [True, False]) def test_predictor_equal( sam_h: SegmentAnythingH, - hq_adapter_weights: Path, + sam_h_hq_adapter_weights_path: Path, hq_mask_only: bool, reference_sam_h_predictor: FacebookSAMPredictorHQ, tennis: Image.Image, one_prompt: SAMPrompt, ) -> None: - adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() adapter.hq_mask_only = hq_mask_only assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only @@ -318,8 +303,8 @@ def test_predictor_equal( @no_grad() -def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None: - HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() +def test_batch_mask_decoder(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject() batch_size = 5 @@ -348,8 +333,10 @@ def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) - assert torch.equal(mask_prediction[0], mask_prediction[1]) -def test_hq_sam_load_save_weights(sam_h: SegmentAnythingH, hq_adapter_weights: Path, test_device: torch.device) -> None: - weights = load_from_safetensors(hq_adapter_weights, device=test_device) +def test_hq_sam_load_save_weights( + sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, test_device: torch.device +) -> None: + weights = load_from_safetensors(sam_h_hq_adapter_weights_path, device=test_device) hq_sam_adapter = HQSAMAdapter(sam_h) out_weights_init = hq_sam_adapter.weights diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index b87e566..95eed70 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -1,7 +1,6 @@ from math import isclose from pathlib import Path from typing import cast -from warnings import warn import numpy as np import pytest @@ -17,8 +16,8 @@ from tests.foundationals.segment_anything.utils import ( from torch import Tensor import refiners.fluxion.layers as fl +from refiners.conversion.model_converter import ModelConverter from refiners.fluxion import manual_seed -from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder @@ -49,20 +48,11 @@ def one_prompt() -> SAMPrompt: @pytest.fixture(scope="module") -def facebook_sam_h_weights(test_weights_path: Path) -> Path: - sam_h_weights = test_weights_path / "sam_vit_h_4b8939.pth" - if not sam_h_weights.is_file(): - warn(f"could not find weights at {sam_h_weights}, skipping") - pytest.skip(allow_module_level=True) - return sam_h_weights - - -@pytest.fixture(scope="module") -def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM: +def facebook_sam_h(sam_h_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM: from segment_anything import build_sam_vit_h # type: ignore sam_h = cast(FacebookSAM, build_sam_vit_h()) - sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights)) + sam_h.load_state_dict(state_dict=load_tensors(sam_h_unconverted_weights_path)) return sam_h.to(device=test_device) @@ -76,16 +66,16 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto @pytest.fixture(scope="module") -def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: +def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH: sam_h = SegmentAnythingH(device=test_device) - sam_h.load_from_safetensors(tensors_path=sam_h_weights) + sam_h.load_from_safetensors(tensors_path=sam_h_weights_path) return sam_h @pytest.fixture(scope="module") -def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: +def sam_h_single_output(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH: sam_h = SegmentAnythingH(multimask_output=False, device=test_device) - sam_h.load_from_safetensors(tensors_path=sam_h_weights) + sam_h.load_from_safetensors(tensors_path=sam_h_weights_path) return sam_h @@ -469,7 +459,10 @@ def test_predictor_resized_single_output( def test_mask_encoder( - facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt + facebook_sam_h_predictor: FacebookSAMPredictor, + sam_h: SegmentAnythingH, + truck: Image.Image, + one_prompt: SAMPrompt, ) -> None: predictor = facebook_sam_h_predictor predictor.set_image(np.array(truck)) diff --git a/tests/training_utils/test_metrics.py b/tests/training_utils/test_metrics.py index 9dcaf0a..801391f 100644 --- a/tests/training_utils/test_metrics.py +++ b/tests/training_utils/test_metrics.py @@ -1,5 +1,4 @@ from pathlib import Path -from warnings import warn import pytest import torch @@ -25,16 +24,11 @@ class CifarDataset(Dataset[torch.Tensor]): @pytest.fixture(scope="module") def dinov2_l( - test_weights_path: Path, + dinov2_large_weights_path: Path, test_device: torch.device, ) -> dinov2.DINOv2_large: - weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors" - if not weights.is_file(): - warn(f"could not find weights at {weights}, skipping") - pytest.skip(allow_module_level=True) - model = dinov2.DINOv2_large(device=test_device) - model.load_from_safetensors(weights) + model.load_from_safetensors(dinov2_large_weights_path) return model diff --git a/tests/utils.py b/tests/utils.py index df0f2a2..963b44c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,11 +24,20 @@ def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int class T5TextEmbedder(nn.Module): def __init__( - self, pretrained_path: Path = Path("tests/weights/QQGYLab/T5XLFP16"), max_length: int | None = None + self, + pretrained_path: Path | str, + max_length: int | None = None, + local_files_only: bool = False, ) -> None: super().__init__() # type: ignore[reportUnknownMemberType] - self.model: nn.Module = T5EncoderModel.from_pretrained(pretrained_path, local_files_only=True) # type: ignore - self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained(pretrained_path, local_files_only=True) # type: ignore + self.model: nn.Module = T5EncoderModel.from_pretrained( # type: ignore + pretrained_path, + local_files_only=local_files_only, + ) + self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained( # type: ignore + pretrained_path, + local_files_only=local_files_only, + ) self.max_length = max_length def forward(