update tests to use new fixtures

This commit is contained in:
Laurent 2024-10-09 09:28:34 +00:00 committed by Laureηt
parent 94eeb1afc3
commit 316fe6e4f0
26 changed files with 836 additions and 1070 deletions

View file

@ -1,5 +1,4 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
@ -16,14 +15,8 @@ def manager() -> SDLoraManager:
@pytest.fixture
def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
return load_tensors(weights_path)
def weights(lora_pokemon_weights_path: Path) -> dict[str, torch.Tensor]:
return load_tensors(lora_pokemon_weights_path)
def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:

View file

@ -9,6 +9,8 @@ import torch
from PIL import Image
from tests.utils import T5TextEmbedder, ensure_similar_images
from refiners.conversion import controllora_sdxl
from refiners.conversion.utils import Hub
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
@ -45,6 +47,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler i
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
from ..weight_paths import get_path
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@ -194,69 +198,85 @@ def expected_style_aligned(ref_path: Path) -> Image.Image:
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
ref_path: Path,
controlnet_depth_weights_path: Path,
controlnet_canny_weights_path: Path,
controlnet_lineart_weights_path: Path,
controlnet_normals_weights_path: Path,
controlnet_sam_weights_path: Path,
request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_fn = {
"depth": "lllyasviel_control_v11f1p_sd15_depth",
"canny": "lllyasviel_control_v11p_sd15_canny",
"lineart": "lllyasviel_control_v11p_sd15_lineart",
"normals": "lllyasviel_control_v11p_sd15_normalbae",
"sam": "mfidabel_controlnet-segment-anything",
}
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
yield (cn_name, condition_image, expected_image, weights_path)
weights_fn = {
"depth": controlnet_depth_weights_path,
"canny": controlnet_canny_weights_path,
"lineart": controlnet_lineart_weights_path,
"normals": controlnet_normals_weights_path,
"sam": controlnet_sam_weights_path,
}
weights_path = weights_fn[cn_name]
yield cn_name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module", params=["canny"])
def controlnet_data_scale_decay(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
ref_path: Path,
controlnet_canny_weights_path: Path,
request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
weights_fn = {
"canny": "lllyasviel_control_v11p_sd15_canny",
}
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
weights_fn = {
"canny": controlnet_canny_weights_path,
}
weights_path = weights_fn[cn_name]
yield (cn_name, condition_image, expected_image, weights_path)
@pytest.fixture(scope="module")
def controlnet_data_tile(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Image.Image, Path]:
def controlnet_data_tile(
ref_path: Path,
controlnet_tiles_weights_path: Path,
) -> tuple[Image.Image, Image.Image, Path]:
condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore
expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
return condition_image, expected_image, weights_path
return condition_image, expected_image, controlnet_tiles_weights_path
@pytest.fixture(scope="module")
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
def controlnet_data_canny(
ref_path: Path,
controlnet_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "canny"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
return cn_name, condition_image, expected_image, weights_path
return cn_name, condition_image, expected_image, controlnet_canny_weights_path
@pytest.fixture(scope="module")
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
def controlnet_data_depth(
ref_path: Path,
controlnet_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
return cn_name, condition_image, expected_image, weights_path
return cn_name, condition_image, expected_image, controlnet_depth_weights_path
@dataclass
class ControlLoraConfig:
scale: float
condition_path: str
weights_path: str
weights: Hub
@dataclass
@ -271,38 +291,38 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
"PyraCanny": ControlLoraConfig(
scale=1.0,
condition_path="cutecat_guide_PyraCanny.png",
weights_path="refiners_control-lora-canny-rank128.safetensors",
weights=controllora_sdxl.canny.converted,
),
},
"expected_controllora_CPDS.png": {
"CPDS": ControlLoraConfig(
scale=1.0,
condition_path="cutecat_guide_CPDS.png",
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
weights=controllora_sdxl.cpds.converted,
),
},
"expected_controllora_PyraCanny+CPDS.png": {
"PyraCanny": ControlLoraConfig(
scale=0.55,
condition_path="cutecat_guide_PyraCanny.png",
weights_path="refiners_control-lora-canny-rank128.safetensors",
weights=controllora_sdxl.canny.converted,
),
"CPDS": ControlLoraConfig(
scale=0.55,
condition_path="cutecat_guide_CPDS.png",
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
weights=controllora_sdxl.cpds.converted,
),
},
"expected_controllora_disabled.png": {
"PyraCanny": ControlLoraConfig(
scale=0.0,
condition_path="cutecat_guide_PyraCanny.png",
weights_path="refiners_control-lora-canny-rank128.safetensors",
weights=controllora_sdxl.canny.converted,
),
"CPDS": ControlLoraConfig(
scale=0.0,
condition_path="cutecat_guide_CPDS.png",
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
weights=controllora_sdxl.cpds.converted,
),
},
}
@ -311,8 +331,8 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
@pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
def controllora_sdxl_config(
request: pytest.FixtureRequest,
use_local_weights: bool,
ref_path: Path,
test_weights_path: Path,
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
name: str = request.param[0]
configs: dict[str, ControlLoraConfig] = request.param[1]
@ -322,7 +342,7 @@ def controllora_sdxl_config(
config_name: ControlLoraResolvedConfig(
scale=config.scale,
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
weights_path=test_weights_path / "control-loras" / config.weights_path,
weights_path=get_path(config.weights, use_local_weights),
)
for config_name, config in configs.items()
}
@ -331,66 +351,57 @@ def controllora_sdxl_config(
@pytest.fixture(scope="module")
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
def t2i_adapter_data_depth(
ref_path: Path,
t2i_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
return name, condition_image, expected_image, weights_path
return name, condition_image, expected_image, t2i_depth_weights_path
@pytest.fixture(scope="module")
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
def t2i_adapter_xl_data_canny(
ref_path: Path,
t2i_sdxl_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
name = "canny"
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
return name, condition_image, expected_image, weights_path
return name, condition_image, expected_image, t2i_sdxl_canny_weights_path
@pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
def lora_data_pokemon(
ref_path: Path,
lora_pokemon_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_tensors(weights_path)
tensors = load_tensors(lora_pokemon_weights_path)
return expected_image, tensors
@pytest.fixture(scope="module")
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
def lora_data_dpo(
ref_path: Path,
lora_dpo_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights_path)
tensors = load_from_safetensors(lora_dpo_weights_path)
return expected_image, tensors
@pytest.fixture(scope="module")
def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
weights_path = test_weights_path / "loras" / "sliders"
if not weights_path.is_dir():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
def lora_sliders(
lora_slider_age_weights_path: Path,
lora_slider_cartoon_style_weights_path: Path,
lora_slider_eyesize_weights_path: Path,
) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
return {
"age": load_tensors(weights_path / "age.pt"), # type: ignore
"cartoon_style": load_tensors(weights_path / "cartoon_style.pt"), # type: ignore
"eyesize": load_tensors(weights_path / "eyesize.pt"), # type: ignore
"age": load_tensors(lora_slider_age_weights_path),
"cartoon_style": load_tensors(lora_slider_cartoon_style_weights_path),
"eyesize": load_tensors(lora_slider_eyesize_weights_path),
}, {
"age": 0.3,
"cartoon_style": -0.2,
@ -477,122 +488,12 @@ def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
@pytest.fixture(scope="module")
def text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not text_encoder_weights.is_file():
warn(f"could not find weights at {text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return text_encoder_weights
@pytest.fixture(scope="module")
def lda_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def unet_weights_std(test_weights_path: Path) -> Path:
unet_weights_std = test_weights_path / "unet.safetensors"
if not unet_weights_std.is_file():
warn(f"could not find weights at {unet_weights_std}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_std
@pytest.fixture(scope="module")
def unet_weights_inpainting(test_weights_path: Path) -> Path:
unet_weights_inpainting = test_weights_path / "inpainting" / "unet.safetensors"
if not unet_weights_inpainting.is_file():
warn(f"could not find weights at {unet_weights_inpainting}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_inpainting
@pytest.fixture(scope="module")
def lda_ft_mse_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda_ft_mse.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def ella_weights(test_weights_path: Path) -> tuple[Path, Path]:
ella_adapter_weights = test_weights_path / "ELLA-Adapter" / "ella-sd1.5-tsc-t5xl.safetensors"
if not ella_adapter_weights.is_file():
warn(f"could not find weights at {ella_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
t5xl_weights = test_weights_path / "QQGYLab" / "T5XLFP16"
t5xl_files = [
"config.json",
"model.safetensors",
"special_tokens_map.json",
"spiece.model",
"tokenizer_config.json",
"tokenizer.json",
]
for file in t5xl_files:
if not (t5xl_weights / file).is_file():
warn(f"could not find weights at {t5xl_weights / file}, skipping")
pytest.skip(allow_module_level=True)
return (ella_adapter_weights, t5xl_weights)
@pytest.fixture(scope="module")
def ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def ip_adapter_plus_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not image_encoder_weights.is_file():
warn(f"could not find weights at {image_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return image_encoder_weights
@pytest.fixture
def sd15_std(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -600,16 +501,19 @@ def sd15_std(
sd15 = StableDiffusion_1(device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sd15_std_sde(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -618,16 +522,19 @@ def sd15_std_sde(
sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sd15_std_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -635,18 +542,18 @@ def sd15_std_float16(
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sd15_std_bfloat16(
text_encoder_weights: Path,
lda_weights: Path,
unet_weights_std: Path,
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
@ -655,16 +562,19 @@ def sd15_std_bfloat16(
sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sd15_inpainting(
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_inpainting_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1_Inpainting:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -673,16 +583,19 @@ def sd15_inpainting(
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_inpainting)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
return sd15
@pytest.fixture
def sd15_inpainting_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_inpainting_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1_Inpainting:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -691,16 +604,19 @@ def sd15_inpainting_float16(
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_inpainting)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
return sd15
@pytest.fixture
def sd15_ddim(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -709,16 +625,19 @@ def sd15_ddim(
ddim_solver = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sd15_ddim_karras(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -727,16 +646,18 @@ def sd15_ddim_karras(
ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sd15_euler(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -745,16 +666,19 @@ def sd15_euler(
euler_solver = Euler(num_inference_steps=30)
sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sd15_ddim_lda_ft_mse(
text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_mse_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
@ -763,52 +687,19 @@ def sd15_ddim_lda_ft_mse(
ddim_solver = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
sd15.lda.load_from_safetensors(sd15_autoencoder_mse_weights_path)
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
return sd15
@pytest.fixture
def sdxl_lda_weights(test_weights_path: Path) -> Path:
sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
if not sdxl_lda_weights.is_file():
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_lda_weights
@pytest.fixture
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not sdxl_lda_weights.is_file():
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_lda_weights
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
if not sdxl_unet_weights.is_file():
warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_unet_weights
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not sdxl_double_text_encoder_weights.is_file():
warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_double_text_encoder_weights
@pytest.fixture
def sdxl_ddim(
sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
sdxl_text_encoder_weights_path: Path,
sdxl_autoencoder_weights_path: Path,
sdxl_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_XL:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
@ -817,16 +708,19 @@ def sdxl_ddim(
solver = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_weights_path)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
return sdxl
@pytest.fixture
def sdxl_ddim_lda_fp16_fix(
sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
sdxl_text_encoder_weights_path: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_XL:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
@ -835,9 +729,9 @@ def sdxl_ddim_lda_fp16_fix(
solver = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
return sdxl
@ -856,23 +750,18 @@ def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_X
@pytest.fixture(scope="module")
def multi_upscaler(
test_weights_path: Path,
unet_weights_std: Path,
text_encoder_weights: Path,
lda_ft_mse_weights: Path,
controlnet_tiles_weights_path: Path,
sd15_text_encoder_weights_path: Path,
sd15_autoencoder_mse_weights_path: Path,
sd15_unet_weights_path: Path,
test_device: torch.device,
) -> MultiUpscaler:
controlnet_tile_weights = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
if not controlnet_tile_weights.is_file():
warn(message=f"could not find weights at {controlnet_tile_weights}, skipping")
pytest.skip(allow_module_level=True)
return MultiUpscaler(
checkpoints=UpscalerCheckpoints(
unet=unet_weights_std,
clip_text_encoder=text_encoder_weights,
lda=lda_ft_mse_weights,
controlnet_tile=controlnet_tile_weights,
unet=sd15_unet_weights_path,
clip_text_encoder=sd15_text_encoder_weights_path,
lda=sd15_autoencoder_mse_weights_path,
controlnet_tile=controlnet_tiles_weights_path,
),
device=test_device,
dtype=torch.float32,
@ -891,7 +780,9 @@ def expected_multi_upscaler(ref_path: Path) -> Image.Image:
@no_grad()
def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
sd15_std: StableDiffusion_1,
expected_image_std_random_init: Image.Image,
test_device: torch.device,
):
sd15 = sd15_std
@ -1553,6 +1444,9 @@ def test_diffusion_sdxl_control_lora(
adapters: dict[str, ControlLoraAdapter] = {}
for config_name, config in configs.items():
if not config.weights_path.is_file():
pytest.skip(f"File not found: {config.weights_path}")
adapter = ControlLoraAdapter(
name=config_name,
scale=config.scale,
@ -1922,13 +1816,18 @@ def test_diffusion_textual_inversion_random_init(
@no_grad()
def test_diffusion_ella_adapter(
sd15_std_float16: StableDiffusion_1,
ella_weights: tuple[Path, Path],
ella_sd15_tsc_t5xl_weights_path: Path,
t5xl_transformers_path: str,
expected_image_ella_adapter: Image.Image,
test_device: torch.device,
use_local_weights: bool,
):
sd15 = sd15_std_float16
ella_adapter_weights, t5xl_weights = ella_weights
t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16)
t5_encoder = T5TextEmbedder(
pretrained_path=t5xl_transformers_path,
local_files_only=use_local_weights,
max_length=128,
).to(test_device, torch.float16)
prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
negative_prompt = ""
@ -1938,7 +1837,7 @@ def test_diffusion_ella_adapter(
llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt)
prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)
adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_adapter_weights))
adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_sd15_tsc_t5xl_weights_path))
adapter.inject()
sd15.set_inference_steps(50)
manual_seed(1001)
@ -1959,8 +1858,8 @@ def test_diffusion_ella_adapter(
@no_grad()
def test_diffusion_ip_adapter(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_weights: Path,
image_encoder_weights: Path,
ip_adapter_sd15_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
woman_image: Image.Image,
expected_image_ip_adapter_woman: Image.Image,
test_device: torch.device,
@ -1976,8 +1875,8 @@ def test_diffusion_ip_adapter(
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
@ -2004,8 +1903,8 @@ def test_diffusion_ip_adapter(
@no_grad()
def test_diffusion_ip_adapter_multi(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_weights: Path,
image_encoder_weights: Path,
ip_adapter_sd15_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
woman_image: Image.Image,
statue_image: Image.Image,
expected_image_ip_adapter_multi: Image.Image,
@ -2016,8 +1915,8 @@ def test_diffusion_ip_adapter_multi(
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
@ -2044,8 +1943,8 @@ def test_diffusion_ip_adapter_multi(
@no_grad()
def test_diffusion_sdxl_ip_adapter(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_weights: Path,
image_encoder_weights: Path,
ip_adapter_sdxl_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
woman_image: Image.Image,
expected_image_sdxl_ip_adapter_woman: Image.Image,
test_device: torch.device,
@ -2055,8 +1954,8 @@ def test_diffusion_sdxl_ip_adapter(
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
with no_grad():
@ -2093,8 +1992,8 @@ def test_diffusion_sdxl_ip_adapter(
@no_grad()
def test_diffusion_ip_adapter_controlnet(
sd15_ddim: StableDiffusion_1,
ip_adapter_weights: Path,
image_encoder_weights: Path,
ip_adapter_sd15_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
lora_data_pokemon: tuple[Image.Image, Path],
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
expected_image_ip_adapter_controlnet: Image.Image,
@ -2107,8 +2006,8 @@ def test_diffusion_ip_adapter_controlnet(
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
depth_controlnet = SD1ControlnetAdapter(
@ -2149,8 +2048,8 @@ def test_diffusion_ip_adapter_controlnet(
@no_grad()
def test_diffusion_ip_adapter_plus(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_plus_weights: Path,
image_encoder_weights: Path,
ip_adapter_sd15_plus_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
statue_image: Image.Image,
expected_image_ip_adapter_plus_statue: Image.Image,
test_device: torch.device,
@ -2161,9 +2060,9 @@ def test_diffusion_ip_adapter_plus(
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter(
target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True
target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_plus_weights_path), fine_grained=True
)
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
@ -2190,8 +2089,8 @@ def test_diffusion_ip_adapter_plus(
@no_grad()
def test_diffusion_sdxl_ip_adapter_plus(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_plus_weights: Path,
image_encoder_weights: Path,
ip_adapter_sdxl_plus_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
woman_image: Image.Image,
expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
test_device: torch.device,
@ -2202,9 +2101,9 @@ def test_diffusion_sdxl_ip_adapter_plus(
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SDXLIPAdapter(
target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True
target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path), fine_grained=True
)
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
@ -2608,8 +2507,8 @@ def test_freeu(
def test_hello_world(
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
sdxl_ip_adapter_weights: Path,
image_encoder_weights: Path,
ip_adapter_sdxl_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
) -> None:
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
@ -2622,8 +2521,8 @@ def test_hello_world(
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
@ -2743,24 +2642,19 @@ def expected_ic_light(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_ic_light.png").convert("RGB")
@pytest.fixture(scope="module")
def ic_light_sd15_fc_weights(test_weights_path: Path) -> Path:
return test_weights_path / "iclight_sd15_fc-refiners.safetensors"
@pytest.fixture(scope="module")
def ic_light_sd15_fc(
ic_light_sd15_fc_weights: Path,
unet_weights_std: Path,
lda_weights: Path,
text_encoder_weights: Path,
ic_light_sd15_fc_weights_path: Path,
sd15_unet_weights_path: Path,
sd15_autoencoder_weights_path: Path,
sd15_text_encoder_weights_path: Path,
test_device: torch.device,
) -> ICLight:
return ICLight(
patch_weights=load_from_safetensors(ic_light_sd15_fc_weights),
unet=SD1UNet(in_channels=4).load_from_safetensors(unet_weights_std),
lda=SD1Autoencoder().load_from_safetensors(lda_weights),
clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(text_encoder_weights),
patch_weights=load_from_safetensors(ic_light_sd15_fc_weights_path),
unet=SD1UNet(in_channels=4).load_from_safetensors(sd15_unet_weights_path),
lda=SD1Autoencoder().load_from_safetensors(sd15_autoencoder_weights_path),
clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(sd15_text_encoder_weights_path),
device=test_device,
)

View file

@ -29,74 +29,11 @@ def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_doc_examples_ref"
@pytest.fixture(scope="module")
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def sdxl_unet_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "sdxl-unet.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
if not path.is_file():
warn(f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "CLIPImageEncoderH.safetensors"
if not path.is_file():
warn(f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture
def scifi_lora_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture
def pixelart_lora_weights(test_weights_path: Path) -> Path:
path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors"
if not path.is_file():
warn(message=f"could not find weights at {path}, skipping")
pytest.skip(allow_module_level=True)
return path
@pytest.fixture
def sdxl(
sdxl_text_encoder_weights: Path,
sdxl_lda_fp16_fix_weights: Path,
sdxl_unet_weights: Path,
sdxl_text_encoder_weights_path: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights_path: Path,
test_device: torch.device,
) -> StableDiffusion_XL:
if test_device.type == "cpu":
@ -105,9 +42,9 @@ def sdxl(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
return sdxl
@ -180,7 +117,7 @@ def test_guide_adapting_sdxl_vanilla(
def test_guide_adapting_sdxl_single_lora(
test_device: torch.device,
sdxl: StableDiffusion_XL,
scifi_lora_weights: Path,
lora_scifi_weights_path: Path,
expected_image_guide_adapting_sdxl_single_lora: Image.Image,
) -> None:
if test_device.type == "cpu":
@ -195,7 +132,7 @@ def test_guide_adapting_sdxl_single_lora(
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manager = SDLoraManager(sdxl)
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt + ", best quality, high quality",
@ -222,8 +159,8 @@ def test_guide_adapting_sdxl_single_lora(
def test_guide_adapting_sdxl_multiple_loras(
test_device: torch.device,
sdxl: StableDiffusion_XL,
scifi_lora_weights: Path,
pixelart_lora_weights: Path,
lora_scifi_weights_path: Path,
lora_pixelart_weights_path: Path,
expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
) -> None:
if test_device.type == "cpu":
@ -238,8 +175,8 @@ def test_guide_adapting_sdxl_multiple_loras(
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manager = SDLoraManager(sdxl)
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4)
manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.4)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt + ", best quality, high quality",
@ -266,10 +203,10 @@ def test_guide_adapting_sdxl_multiple_loras(
def test_guide_adapting_sdxl_loras_ip_adapter(
test_device: torch.device,
sdxl: StableDiffusion_XL,
sdxl_ip_adapter_plus_weights: Path,
image_encoder_weights: Path,
scifi_lora_weights: Path,
pixelart_lora_weights: Path,
ip_adapter_sdxl_plus_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
lora_scifi_weights_path: Path,
lora_pixelart_weights_path: Path,
image_prompt_german_castle: Image.Image,
expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
) -> None:
@ -285,16 +222,16 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manager = SDLoraManager(sdxl)
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5)
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55)
manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path), scale=1.5)
manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.55)
ip_adapter = SDXLIPAdapter(
target=sdxl.unet,
weights=load_from_safetensors(sdxl_ip_adapter_plus_weights),
weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path),
scale=1.0,
fine_grained=True,
)
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
ip_adapter.inject()
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(

View file

@ -26,51 +26,6 @@ def ensure_gc():
gc.collect()
@pytest.fixture
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lcm_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lcm-unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lcm_lora_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lcm-lora.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_lcm_ref"
@ -94,9 +49,9 @@ def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
@no_grad()
def test_lcm_base(
test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path,
sdxl_lcm_unet_weights: Path,
sdxl_text_encoder_weights: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_lcm_weights_path: Path,
sdxl_text_encoder_weights_path: Path,
expected_lcm_base: Image.Image,
) -> None:
if test_device.type == "cpu":
@ -111,9 +66,9 @@ def test_lcm_base(
# not in the diffusion loop.
SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(sdxl_lcm_unet_weights)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_unet_lcm_weights_path)
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_base
@ -141,10 +96,10 @@ def test_lcm_base(
@pytest.mark.parametrize("condition_scale", [1.0, 1.2])
def test_lcm_lora_with_guidance(
test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path,
sdxl_unet_weights: Path,
sdxl_text_encoder_weights: Path,
sdxl_lcm_lora_weights: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights_path: Path,
sdxl_text_encoder_weights_path: Path,
lora_sdxl_lcm_weights_path: Path,
expected_lcm_lora_1_0: Image.Image,
expected_lcm_lora_1_2: Image.Image,
condition_scale: float,
@ -156,12 +111,12 @@ def test_lcm_lora_with_guidance(
solver = LCMSolver(num_inference_steps=4)
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
manager = SDLoraManager(sdxl)
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
@ -191,10 +146,10 @@ def test_lcm_lora_with_guidance(
@no_grad()
def test_lcm_lora_without_guidance(
test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path,
sdxl_unet_weights: Path,
sdxl_text_encoder_weights: Path,
sdxl_lcm_lora_weights: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights_path: Path,
sdxl_text_encoder_weights_path: Path,
lora_sdxl_lcm_weights_path: Path,
expected_lcm_lora_1_0: Image.Image,
) -> None:
if test_device.type == "cpu":
@ -205,12 +160,12 @@ def test_lcm_lora_without_guidance(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
manager = SDLoraManager(sdxl)
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0

View file

@ -25,60 +25,6 @@ def ensure_gc():
gc.collect()
@pytest.fixture
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl-unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lightning_4step_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl_lightning_4step_unet.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lightning_1step_unet_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl_lightning_1step_unet_x0.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture
def sdxl_lightning_4step_lora_weights(test_weights_path: Path) -> Path:
r = test_weights_path / "sdxl_lightning_4step_lora.safetensors"
if not r.is_file():
warn(f"could not find weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_lightning_ref"
@ -102,16 +48,16 @@ def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
@no_grad()
def test_lightning_base_4step(
test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path,
sdxl_lightning_4step_unet_weights: Path,
sdxl_text_encoder_weights: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_lightning_4step_weights_path: Path,
sdxl_text_encoder_weights_path: Path,
expected_lightning_base_4step: Image.Image,
) -> None:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
unet_weights = sdxl_lightning_4step_unet_weights
unet_weights = sdxl_unet_lightning_4step_weights_path
expected_image = expected_lightning_base_4step
solver = Euler(
@ -125,8 +71,8 @@ def test_lightning_base_4step(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(unet_weights)
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
@ -153,16 +99,16 @@ def test_lightning_base_4step(
@no_grad()
def test_lightning_base_1step(
test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path,
sdxl_lightning_1step_unet_weights: Path,
sdxl_text_encoder_weights: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_lightning_1step_weights_path: Path,
sdxl_text_encoder_weights_path: Path,
expected_lightning_base_1step: Image.Image,
) -> None:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
unet_weights = sdxl_lightning_1step_unet_weights
unet_weights = sdxl_unet_lightning_1step_weights_path
expected_image = expected_lightning_base_1step
solver = Euler(
@ -176,8 +122,8 @@ def test_lightning_base_1step(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(unet_weights)
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
@ -204,10 +150,10 @@ def test_lightning_base_1step(
@no_grad()
def test_lightning_lora_4step(
test_device: torch.device,
sdxl_lda_fp16_fix_weights: Path,
sdxl_unet_weights: Path,
sdxl_text_encoder_weights: Path,
sdxl_lightning_4step_lora_weights: Path,
sdxl_autoencoder_fp16fix_weights_path: Path,
sdxl_unet_weights_path: Path,
sdxl_text_encoder_weights_path: Path,
lora_sdxl_lightning_4step_weights_path: Path,
expected_lightning_lora_4step: Image.Image,
) -> None:
if test_device.type == "cpu":
@ -227,12 +173,12 @@ def test_lightning_lora_4step(
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
sdxl.classifier_free_guidance = False
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
manager = SDLoraManager(sdxl)
add_lcm_lora(manager, load_from_safetensors(sdxl_lightning_4step_lora_weights), name="lightning")
add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lightning_4step_weights_path), name="lightning")
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"

View file

@ -29,19 +29,10 @@ def expected_cactus_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cactus_mask.png")
@pytest.fixture(scope="module")
def mvanet_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "mvanet" / "mvanet.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
@pytest.fixture
def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet:
def mvanet_model(mvanet_weights_path: Path, test_device: torch.device) -> MVANet:
model = MVANet(device=test_device).eval() # .eval() is important!
model.load_from_safetensors(mvanet_weights)
model.load_from_safetensors(mvanet_weights_path)
return model
@ -61,7 +52,7 @@ def test_mvanet(
@no_grad()
def test_mvanet_to(
mvanet_weights: Path,
mvanet_weights_path: Path,
ref_cactus: Image.Image,
expected_cactus_mask: Image.Image,
test_device: torch.device,
@ -71,7 +62,7 @@ def test_mvanet_to(
pytest.skip()
model = MVANet(device=torch.device("cpu")).eval()
model.load_from_safetensors(mvanet_weights)
model.load_from_safetensors(mvanet_weights_path)
model.to(test_device)
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()

View file

@ -1,5 +1,4 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
@ -29,19 +28,13 @@ def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image
return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
@pytest.fixture(scope="module")
def informative_drawings_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "informative-drawings.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
@pytest.fixture
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
def informative_drawings_model(
controlnet_preprocessor_info_drawings_weights_path: Path,
test_device: torch.device,
) -> InformativeDrawings:
model = InformativeDrawings(device=test_device)
model.load_from_safetensors(informative_drawings_weights)
model.load_from_safetensors(controlnet_preprocessor_info_drawings_weights_path)
return model

View file

@ -1,5 +1,4 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
@ -38,24 +37,15 @@ def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
@pytest.fixture(scope="module")
def box_segmenter_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
def test_box_segmenter(
box_segmenter_weights: Path,
box_segmenter_weights_path: Path,
ref_shelves: Image.Image,
expected_box_segmenter_plant_mask: Image.Image,
expected_box_segmenter_spray_mask: Image.Image,
expected_box_segmenter_spray_cropped_mask: Image.Image,
test_device: torch.device,
):
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
segmenter = BoxSegmenter(weights=box_segmenter_weights_path, device=test_device)
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))

View file

@ -4,7 +4,7 @@ import torch
from torch import Tensor, nn
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
from refiners.conversion.model_converter import ConversionStage, ModelConverter
from refiners.fluxion.utils import manual_seed

View file

@ -1,5 +1,4 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
@ -20,17 +19,13 @@ PROMPTS = [
@pytest.fixture(scope="module")
def our_encoder_with_new_concepts(
test_weights_path: Path,
sd15_text_encoder_weights_path: Path,
test_device: torch.device,
cat_embedding_textual_inversion: torch.Tensor,
gta5_artwork_embedding_textual_inversion: torch.Tensor,
) -> CLIPTextEncoderL:
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = CLIPTextEncoderL(device=test_device)
tensors = load_from_safetensors(weights)
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
encoder.load_state_dict(tensors)
concept_extender = ConceptExtender(encoder)
concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion)
@ -41,24 +36,21 @@ def our_encoder_with_new_concepts(
@pytest.fixture(scope="module")
def ref_sd15_with_new_concepts(
runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device
sd15_diffusers_runwayml_path: str,
test_textual_inversion_path: Path,
test_device: torch.device,
use_local_weights: bool,
) -> StableDiffusionPipeline:
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path).to(test_device) # type: ignore
pipe = StableDiffusionPipeline.from_pretrained( # type: ignore
sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
).to(test_device) # type: ignore
assert isinstance(pipe, StableDiffusionPipeline)
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
return pipe
@pytest.fixture(scope="module")
def runwayml_weights_path(test_weights_path: Path):
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
if not r.is_dir():
warn(f"could not find RunwayML weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer:
return ref_sd15_with_new_concepts.tokenizer # type: ignore

View file

@ -1,5 +1,4 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
@ -11,39 +10,28 @@ from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@pytest.fixture(scope="module")
def our_encoder(
test_weights_path: Path,
clip_image_encoder_huge_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> CLIPImageEncoderH:
weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
tensors = load_from_safetensors(weights)
tensors = load_from_safetensors(clip_image_encoder_huge_weights_path)
encoder.load_state_dict(tensors)
return encoder
@pytest.fixture(scope="module")
def stabilityai_unclip_weights_path(test_weights_path: Path):
r = test_weights_path / "stabilityai" / "stable-diffusion-2-1-unclip"
if not r.is_dir():
warn(f"could not find Stability AI weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_encoder(
stabilityai_unclip_weights_path: Path,
unclip21_transformers_stabilityai_path: str,
test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype,
use_local_weights: bool,
) -> CLIPVisionModelWithProjection:
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
stabilityai_unclip_weights_path,
unclip21_transformers_stabilityai_path,
local_files_only=use_local_weights,
subfolder="image_encoder",
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) # type: ignore
@no_grad()

View file

@ -1,5 +1,4 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
@ -31,42 +30,39 @@ PROMPTS = [
@pytest.fixture(scope="module")
def our_encoder(
test_weights_path: Path,
sd15_text_encoder_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype,
) -> CLIPTextEncoderL:
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights)
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
encoder.load_state_dict(tensors)
return encoder
@pytest.fixture(scope="module")
def runwayml_weights_path(test_weights_path: Path):
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
if not r.is_dir():
warn(f"could not find RunwayML weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
return transformers.CLIPTokenizer.from_pretrained(runwayml_weights_path, subfolder="tokenizer") # type: ignore
def ref_tokenizer(
sd15_diffusers_runwayml_path: str,
use_local_weights: bool,
) -> transformers.CLIPTokenizer:
return transformers.CLIPTokenizer.from_pretrained( # type: ignore
sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
subfolder="tokenizer",
)
@pytest.fixture(scope="module")
def ref_encoder(
runwayml_weights_path: Path,
sd15_diffusers_runwayml_path: str,
test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype,
use_local_weights: bool,
) -> transformers.CLIPTextModel:
return transformers.CLIPTextModel.from_pretrained( # type: ignore
runwayml_weights_path,
sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
subfolder="text_encoder",
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore

View file

@ -4,6 +4,7 @@ from warnings import warn
import pytest
import torch
from huggingface_hub import hf_hub_download # type: ignore
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.dinov2.dinov2 import (
@ -18,7 +19,7 @@ from refiners.foundationals.dinov2.dinov2 import (
)
from refiners.foundationals.dinov2.vit import ViT
FLAVORS_MAP = {
FLAVORS_MAP_REFINERS = {
"dinov2_vits14": DINOv2_small,
"dinov2_vits14_reg": DINOv2_small_reg,
"dinov2_vitb14": DINOv2_base,
@ -28,6 +29,27 @@ FLAVORS_MAP = {
"dinov2_vitg14": DINOv2_giant,
"dinov2_vitg14_reg": DINOv2_giant_reg,
}
FLAVORS_MAP_HUB = {
"dinov2_vits14": "refiners/dinov2.small.patch_14",
"dinov2_vits14_reg": "refiners/dinov2.small.patch_14.reg_4",
"dinov2_vitb14": "refiners/dinov2.base.patch_14",
"dinov2_vitb14_reg": "refiners/dinov2.base.patch_14.reg_4",
"dinov2_vitl14": "refiners/dinov2.large.patch_14",
"dinov2_vitl14_reg": "refiners/dinov2.large.patch_14.reg_4",
"dinov2_vitg14": "refiners/dinov2.giant.patch_14",
"dinov2_vitg14_reg": "refiners/dinov2.giant.patch_14.reg_4",
}
@pytest.fixture(scope="module", params=["float16", "bfloat16"])
def dtype(request: pytest.FixtureRequest) -> torch.dtype:
match request.param:
case "float16":
return torch.float16
case "bfloat16":
return torch.bfloat16
case _ as dtype:
raise ValueError(f"unsupported dtype: {dtype}")
@pytest.fixture(scope="module", params=[224, 518])
@ -35,7 +57,7 @@ def resolution(request: pytest.FixtureRequest) -> int:
return request.param
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
@pytest.fixture(scope="module", params=FLAVORS_MAP_REFINERS.keys())
def flavor(request: pytest.FixtureRequest) -> str:
return request.param
@ -53,7 +75,14 @@ def dinov2_repo_path(test_repos_path: Path) -> Path:
def ref_model(
flavor: str,
dinov2_repo_path: Path,
test_weights_path: Path,
dinov2_small_unconverted_weights_path: Path,
dinov2_small_reg4_unconverted_weights_path: Path,
dinov2_base_unconverted_weights_path: Path,
dinov2_base_reg4_unconverted_weights_path: Path,
dinov2_large_unconverted_weights_path: Path,
dinov2_large_reg4_unconverted_weights_path: Path,
dinov2_giant_unconverted_weights_path: Path,
dinov2_giant_reg4_unconverted_weights_path: Path,
test_device: torch.device,
) -> torch.nn.Module:
kwargs: dict[str, Any] = {}
@ -69,34 +98,51 @@ def ref_model(
)
model = model.to(device=test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.pth"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
model.load_state_dict(load_tensors(weights, device=test_device))
weight_map = {
"dinov2_vits14": dinov2_small_unconverted_weights_path,
"dinov2_vits14_reg": dinov2_small_reg4_unconverted_weights_path,
"dinov2_vitb14": dinov2_base_unconverted_weights_path,
"dinov2_vitb14_reg": dinov2_base_reg4_unconverted_weights_path,
"dinov2_vitl14": dinov2_large_unconverted_weights_path,
"dinov2_vitl14_reg": dinov2_large_reg4_unconverted_weights_path,
"dinov2_vitg14": dinov2_giant_unconverted_weights_path,
"dinov2_vitg14_reg": dinov2_giant_reg4_unconverted_weights_path,
}
weights_path = weight_map[flavor]
model.load_state_dict(load_tensors(weights_path, device=test_device))
assert isinstance(model, torch.nn.Module)
return model
@pytest.fixture(scope="module")
def our_model(
test_weights_path: Path,
flavor: str,
dinov2_small_weights_path: Path,
dinov2_small_reg4_weights_path: Path,
dinov2_base_weights_path: Path,
dinov2_base_reg4_weights_path: Path,
dinov2_large_weights_path: Path,
dinov2_large_reg4_weights_path: Path,
dinov2_giant_weights_path: Path,
dinov2_giant_reg4_weights_path: Path,
test_device: torch.device,
) -> ViT:
model = FLAVORS_MAP[flavor](device=test_device)
weight_map = {
"dinov2_vits14": dinov2_small_weights_path,
"dinov2_vits14_reg": dinov2_small_reg4_weights_path,
"dinov2_vitb14": dinov2_base_weights_path,
"dinov2_vitb14_reg": dinov2_base_reg4_weights_path,
"dinov2_vitl14": dinov2_large_weights_path,
"dinov2_vitl14_reg": dinov2_large_reg4_weights_path,
"dinov2_vitg14": dinov2_giant_weights_path,
"dinov2_vitg14_reg": dinov2_giant_reg4_weights_path,
}
weights_path = weight_map[flavor]
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights)
model = FLAVORS_MAP_REFINERS[flavor](device=test_device)
tensors = load_from_safetensors(weights_path)
model.load_state_dict(tensors)
return model

View file

@ -0,0 +1,139 @@
from pathlib import Path
import pytest
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
DoubleTextEncoder,
SD1Autoencoder,
SD1UNet,
SDXLAutoencoder,
SDXLUNet,
StableDiffusion_1,
StableDiffusion_XL,
)
@pytest.fixture(scope="module")
def refiners_sd15_autoencoder(sd15_autoencoder_weights_path: Path) -> SD1Autoencoder:
autoencoder = SD1Autoencoder()
tensors = load_from_safetensors(sd15_autoencoder_weights_path)
autoencoder.load_state_dict(tensors)
return autoencoder
@pytest.fixture(scope="module")
def refiners_sd15_unet(sd15_unet_weights_path: Path) -> SD1UNet:
unet = SD1UNet(in_channels=4)
tensors = load_from_safetensors(sd15_unet_weights_path)
unet.load_state_dict(tensors)
return unet
@pytest.fixture(scope="module")
def refiners_sd15_text_encoder(sd15_text_encoder_weights_path: Path) -> CLIPTextEncoderL:
text_encoder = CLIPTextEncoderL()
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
text_encoder.load_state_dict(tensors)
return text_encoder
@pytest.fixture(scope="module")
def refiners_sd15(
refiners_sd15_autoencoder: SD1Autoencoder,
refiners_sd15_unet: SD1UNet,
refiners_sd15_text_encoder: CLIPTextEncoderL,
) -> StableDiffusion_1:
return StableDiffusion_1(
lda=refiners_sd15_autoencoder,
unet=refiners_sd15_unet,
clip_text_encoder=refiners_sd15_text_encoder,
)
@pytest.fixture(scope="module")
def refiners_sdxl_autoencoder(sdxl_autoencoder_weights_path: Path) -> SDXLAutoencoder:
autoencoder = SDXLAutoencoder()
tensors = load_from_safetensors(sdxl_autoencoder_weights_path)
autoencoder.load_state_dict(tensors)
return autoencoder
@pytest.fixture(scope="module")
def refiners_sdxl_unet(sdxl_unet_weights_path: Path) -> SDXLUNet:
unet = SDXLUNet(in_channels=4)
tensors = load_from_safetensors(sdxl_unet_weights_path)
unet.load_state_dict(tensors)
return unet
@pytest.fixture(scope="module")
def refiners_sdxl_text_encoder(sdxl_text_encoder_weights_path: Path) -> DoubleTextEncoder:
text_encoder = DoubleTextEncoder()
tensors = load_from_safetensors(sdxl_text_encoder_weights_path)
text_encoder.load_state_dict(tensors)
return text_encoder
@pytest.fixture(scope="module")
def refiners_sdxl(
refiners_sdxl_autoencoder: SDXLAutoencoder,
refiners_sdxl_unet: SDXLUNet,
refiners_sd15_text_encoder: DoubleTextEncoder,
) -> StableDiffusion_XL:
return StableDiffusion_XL(
lda=refiners_sdxl_autoencoder,
unet=refiners_sdxl_unet,
clip_text_encoder=refiners_sd15_text_encoder,
)
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
def refiners_autoencoder(
request: pytest.FixtureRequest,
refiners_sd15_autoencoder: SD1Autoencoder,
refiners_sdxl_autoencoder: SDXLAutoencoder,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> SD1Autoencoder | SDXLAutoencoder:
model_version = request.param
match (model_version, test_dtype_fp32_bf16_fp16):
case ("SD1.5", _):
return refiners_sd15_autoencoder
case ("SDXL", torch.float16):
return refiners_sdxl_autoencoder
case ("SDXL", _):
return refiners_sdxl_autoencoder
case _:
raise ValueError(f"Unknown model version: {model_version}")
@pytest.fixture(scope="module")
def diffusers_sd15_pipeline(
sd15_diffusers_runwayml_path: str,
use_local_weights: bool,
) -> StableDiffusionPipeline:
return StableDiffusionPipeline.from_pretrained( # type: ignore
sd15_diffusers_runwayml_path,
local_files_only=use_local_weights,
)
@pytest.fixture(scope="module")
def diffusers_sdxl_pipeline(
sdxl_diffusers_stabilityai_path: str,
use_local_weights: bool,
) -> StableDiffusionXLPipeline:
return StableDiffusionXLPipeline.from_pretrained( # type: ignore
sdxl_diffusers_stabilityai_path,
local_files_only=use_local_weights,
)
@pytest.fixture(scope="module")
def diffusers_sdxl_unet(diffusers_sdxl_pipeline: StableDiffusionXLPipeline) -> UNet2DConditionModel:
return diffusers_sdxl_pipeline.unet # type: ignore

View file

@ -1,134 +0,0 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
@pytest.fixture(scope="module")
def ref_path() -> Path:
return Path(__file__).parent / "test_auto_encoder_ref"
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
def lda(
request: pytest.FixtureRequest,
test_weights_path: Path,
test_dtype_fp32_bf16_fp16: torch.dtype,
test_device: torch.device,
) -> LatentDiffusionAutoencoder:
model_version = request.param
match (model_version, test_dtype_fp32_bf16_fp16):
case ("SD1.5", _):
weight_path = test_weights_path / "lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SD1Autoencoder().load_from_safetensors(weight_path)
case ("SDXL", torch.float16):
weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case ("SDXL", _):
weight_path = test_weights_path / "sdxl-lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case _:
raise ValueError(f"Unknown model version: {model_version}")
model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
return model
@pytest.fixture(scope="module")
def sample_image(ref_path: Path) -> Image.Image:
test_image = ref_path / "macaw.png"
if not test_image.is_file():
warn(f"could not reference image at {test_image}, skipping")
pytest.skip(allow_module_level=True)
img = Image.open(test_image) # type: ignore
assert img.size == (512, 512)
return img
@no_grad()
def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = lda.image_to_latents(sample_image)
decoded = lda.latents_to_image(encoded)
assert decoded.mode == "RGB" # type: ignore
# Ensure no saturation. The green channel (band = 1) must not max out.
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
@no_grad()
def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = lda.images_to_latents([sample_image, sample_image])
images = lda.latents_to_images(encoded)
assert isinstance(images, list)
assert len(images) == 2
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
@no_grad()
def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(512, 1024)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((1024, 1024)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(2048, 2048)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
@no_grad()
def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.crop((0, 0, 300, 500))
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore
with lda.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = lda.tiled_image_to_latents(sample_image)
result = lda.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
with pytest.raises(ValueError):
lda.tiled_image_to_latents(sample_image)
with pytest.raises(ValueError):
lda.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=lda.device))

View file

@ -0,0 +1,104 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@pytest.fixture(scope="module")
def sample_image() -> Image.Image:
test_image = Path(__file__).parent / "test_auto_encoder_ref" / "macaw.png"
if not test_image.is_file():
warn(f"could not reference image at {test_image}, skipping")
pytest.skip(allow_module_level=True)
img = Image.open(test_image) # type: ignore
assert img.size == (512, 512)
return img
@pytest.fixture(scope="module")
def autoencoder(
refiners_autoencoder: LatentDiffusionAutoencoder,
test_device: torch.device,
) -> LatentDiffusionAutoencoder:
return refiners_autoencoder.to(test_device)
@no_grad()
def test_encode_decode_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = autoencoder.image_to_latents(sample_image)
decoded = autoencoder.latents_to_image(encoded)
assert decoded.mode == "RGB" # type: ignore
# Ensure no saturation. The green channel (band = 1) must not max out.
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
@no_grad()
def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = autoencoder.images_to_latents([sample_image, sample_image])
images = autoencoder.latents_to_images(encoded)
assert isinstance(images, list)
assert len(images) == 2
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
@no_grad()
def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((2048, 2048)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(512, 1024)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.resize((1024, 1024)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(2048, 2048)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
@no_grad()
def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
sample_image = sample_image.crop((0, 0, 300, 500))
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
encoded = autoencoder.tiled_image_to_latents(sample_image)
result = autoencoder.tiled_latents_to_image(encoded)
ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
def test_value_error_tile_encode_no_context(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
with pytest.raises(ValueError):
autoencoder.tiled_image_to_latents(sample_image)
with pytest.raises(ValueError):
autoencoder.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=autoencoder.device))

View file

@ -1,31 +0,0 @@
import torch
from PIL import Image
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_1_Inpainting
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
@no_grad()
def test_sample_noise():
manual_seed(2)
latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))
manual_seed(2)
latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
@no_grad()
def test_sd1_inpainting(test_device: torch.device) -> None:
sd = StableDiffusion_1_Inpainting(device=test_device)
latent_noise = torch.randn(1, 4, 64, 64, device=test_device)
target_image = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512))
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
text_embedding = sd.compute_clip_text_embedding("")
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
assert output.shape == (1, 4, 64, 64)

View file

@ -0,0 +1,65 @@
import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from torch import Tensor
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
@no_grad()
def test_text_encoder(
diffusers_sd15_pipeline: StableDiffusionPipeline,
refiners_sd15_text_encoder: CLIPTextEncoderL,
) -> None:
"""Compare our refiners implementation with the diffusers implementation."""
manual_seed(seed=0) # unnecessary, but just in case
prompt = "A photo of a pizza."
negative_prompt = ""
atol = 1e-2 # FIXME: very high tolerance, figure out why
( # encode text prompts using diffusers pipeline
diffusers_embeds, # type: ignore
diffusers_negative_embeds, # type: ignore
) = diffusers_sd15_pipeline.encode_prompt( # type: ignore
prompt=prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
device=diffusers_sd15_pipeline.device,
)
assert isinstance(diffusers_embeds, Tensor)
assert isinstance(diffusers_negative_embeds, Tensor)
# encode text prompts using refiners model
refiners_embeds = refiners_sd15_text_encoder(prompt)
refiners_negative_embeds = refiners_sd15_text_encoder("")
# check that the shapes are the same
assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 768)
assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 768)
# check that the values are close
assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol)
assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol)
@no_grad()
def test_text_encoder_batched(refiners_sd15_text_encoder: CLIPTextEncoderL) -> None:
"""Check that encoding two prompts works as expected whether batched or not."""
manual_seed(seed=0) # unnecessary, but just in case
prompt1 = "A photo of a pizza."
prompt2 = "A giant duck."
atol = 1e-6
# encode the two prompts at once
embeds_batched = refiners_sd15_text_encoder([prompt1, prompt2])
assert embeds_batched.shape == (2, 77, 768)
# encode the prompts one by one
embeds_1 = refiners_sd15_text_encoder(prompt1)
embeds_2 = refiners_sd15_text_encoder(prompt2)
assert embeds_1.shape == embeds_2.shape == (1, 77, 768)
# check that the values are close
assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol)
assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol)

View file

@ -1,121 +0,0 @@
from pathlib import Path
from typing import Any, Protocol, cast
from warnings import warn
import pytest
import torch
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
class DiffusersSDXL(Protocol):
unet: fl.Module
text_encoder: fl.Module
text_encoder_2: fl.Module
tokenizer: fl.Module
tokenizer_2: fl.Module
vae: fl.Module
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
def encode_prompt(
self,
prompt: str,
prompt_2: str | None = None,
negative_prompt: str | None = None,
negative_prompt_2: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
if not r.is_dir():
warn(message=f"could not find Stability SDXL base weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def double_text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not text_encoder_weights.is_file():
warn(f"could not find weights at {text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return text_encoder_weights
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
from diffusers import DiffusionPipeline # type: ignore
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
@pytest.fixture(scope="module")
def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
double_text_encoder = DoubleTextEncoder()
double_text_encoder.load_from_safetensors(double_text_encoder_weights)
return double_text_encoder
@no_grad()
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt = "A photo of a pizza."
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
double_embedding, pooled_embedding = double_text_encoder(prompt)
assert double_embedding.shape == torch.Size([1, 77, 2048])
assert pooled_embedding.shape == torch.Size([1, 1280])
embedding_1, embedding_2 = cast(
tuple[Tensor, Tensor],
prompt_embeds.split(split_size=[768, 1280], dim=-1), # type: ignore
)
rembedding_1, rembedding_2 = cast(
tuple[Tensor, Tensor],
double_embedding.split(split_size=[768, 1280], dim=-1), # type: ignore
)
assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
@no_grad()
def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt1 = "A photo of a pizza."
prompt2 = "A giant duck."
double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2])
assert double_embedding_b2.shape == torch.Size([2, 77, 2048])
assert pooled_embedding_b2.shape == torch.Size([2, 1280])
double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1)
double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2)
assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)

View file

@ -0,0 +1,70 @@
import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from torch import Tensor
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
@no_grad()
def test_double_text_encoder(
diffusers_sdxl_pipeline: StableDiffusionXLPipeline,
refiners_sdxl_text_encoder: DoubleTextEncoder,
) -> None:
"""Compare our refiners implementation with the diffusers implementation."""
manual_seed(seed=0) # unnecessary, but just in case
prompt = "A photo of a pizza."
negative_prompt = ""
atol = 1e-6
( # encode text prompts using diffusers pipeline
diffusers_embeds,
diffusers_negative_embeds, # type: ignore
diffusers_pooled_embeds, # type: ignore
diffusers_negative_pooled_embeds, # type: ignore
) = diffusers_sdxl_pipeline.encode_prompt(prompt=prompt, negative_prompt=negative_prompt)
assert diffusers_negative_embeds is not None
assert isinstance(diffusers_pooled_embeds, Tensor)
assert isinstance(diffusers_negative_pooled_embeds, Tensor)
# encode text prompts using refiners model
refiners_embeds, refiners_pooled_embeds = refiners_sdxl_text_encoder(prompt)
refiners_negative_embeds, refiners_negative_pooled_embeds = refiners_sdxl_text_encoder("")
# check that the shapes are the same
assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 2048)
assert diffusers_pooled_embeds.shape == refiners_pooled_embeds.shape == (1, 1280)
assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 2048)
assert diffusers_negative_pooled_embeds.shape == refiners_negative_pooled_embeds.shape == (1, 1280)
# check that the values are close
assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol)
assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol)
assert torch.allclose(input=refiners_negative_pooled_embeds, other=diffusers_negative_pooled_embeds, atol=atol)
assert torch.allclose(input=refiners_pooled_embeds, other=diffusers_pooled_embeds, atol=atol)
@no_grad()
def test_double_text_encoder_batched(refiners_sdxl_text_encoder: DoubleTextEncoder) -> None:
"""Check that encoding two prompts works as expected whether batched or not."""
manual_seed(seed=0) # unnecessary, but just in case
prompt1 = "A photo of a pizza."
prompt2 = "A giant duck."
atol = 1e-6
# encode the two prompts at once
embeds_batched, pooled_embeds_batched = refiners_sdxl_text_encoder([prompt1, prompt2])
assert embeds_batched.shape == (2, 77, 2048)
assert pooled_embeds_batched.shape == (2, 1280)
# encode the prompts one by one
embeds_1, pooled_embeds_1 = refiners_sdxl_text_encoder(prompt1)
embeds_2, pooled_embeds_2 = refiners_sdxl_text_encoder(prompt2)
assert embeds_1.shape == embeds_2.shape == (1, 77, 2048)
assert pooled_embeds_1.shape == pooled_embeds_2.shape == (1, 1280)
# check that the values are close
assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol)
assert torch.allclose(input=pooled_embeds_1, other=pooled_embeds_batched[0].unsqueeze(0), atol=atol)
assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol)
assert torch.allclose(input=pooled_embeds_2, other=pooled_embeds_batched[1].unsqueeze(0), atol=atol)

View file

@ -1,36 +1,13 @@
from pathlib import Path
from typing import Any
from warnings import warn
import pytest
import torch
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
from refiners.conversion.model_converter import ConversionStage, ModelConverter
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
if not r.is_dir():
warn(f"could not find Stability SDXL base weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
from diffusers import DiffusionPipeline # type: ignore
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
@pytest.fixture(scope="module")
def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
return diffusers_sdxl.unet
@pytest.fixture(scope="module")
def refiners_sdxl_unet() -> SDXLUNet:
unet = SDXLUNet(in_channels=4)
@ -38,7 +15,10 @@ def refiners_sdxl_unet() -> SDXLUNet:
@no_grad()
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
def test_sdxl_unet(
diffusers_sdxl_unet: Any,
refiners_sdxl_unet: SDXLUNet,
) -> None:
source = diffusers_sdxl_unet
target = refiners_sdxl_unet

View file

@ -1,8 +1,7 @@
import gc
from pathlib import Path
from warnings import warn
from pytest import fixture, skip
from pytest import fixture
@fixture(autouse=True)
@ -15,12 +14,3 @@ def ensure_gc():
@fixture(scope="package")
def ref_path(test_sam_path: Path) -> Path:
return test_sam_path / "test_sam_ref"
@fixture(scope="package")
def sam_h_weights(test_weights_path: Path) -> Path:
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
if not sam_h_weights.is_file():
warn(f"could not find weights at {sam_h_weights}, skipping")
skip(allow_module_level=True)
return sam_h_weights

View file

@ -1,6 +1,5 @@
from pathlib import Path
from typing import cast
from warnings import warn
import numpy as np
import pytest
@ -36,37 +35,17 @@ def tennis(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore
@pytest.fixture(scope="module")
def hq_adapter_weights(test_weights_path: Path) -> Path:
"""Path to the HQ adapter weights in Refiners format"""
refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
if not refiners_hq_adapter_sam_weights.is_file():
warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping")
pytest.skip(allow_module_level=True)
return refiners_hq_adapter_sam_weights
@pytest.fixture
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
# HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False.
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
return sam_h
@pytest.fixture(scope="module")
def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
"""Path to the HQ adapter weights in default format"""
reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth"
if not reference_hq_adapter_sam_weights.is_file():
warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping")
pytest.skip(allow_module_level=True)
return reference_hq_adapter_sam_weights
@pytest.fixture(scope="module")
def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights))
def reference_sam_h(sam_h_hq_adapter_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=sam_h_hq_adapter_unconverted_weights_path))
return sam_h.to(device=test_device)
@ -142,11 +121,11 @@ def test_mask_decoder_tokens_extender() -> None:
@no_grad()
def test_early_vit_embedding(
sam_h: SegmentAnythingH,
hq_adapter_weights: Path,
sam_h_hq_adapter_weights_path: Path,
reference_sam_h: FacebookSAM,
tennis: Image.Image,
) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore
@ -159,8 +138,8 @@ def test_early_vit_embedding(
assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)
def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
def test_tokens(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)
@ -175,8 +154,10 @@ def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam
@no_grad()
def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
def test_compress_vit_feat(
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
@ -189,8 +170,10 @@ def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re
@no_grad()
def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
def test_embedding_encoder(
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
@ -203,8 +186,10 @@ def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re
@no_grad()
def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
def test_hq_token_mlp(
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)
@ -217,13 +202,13 @@ def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, referen
@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor(
sam_h: SegmentAnythingH,
hq_adapter_weights: Path,
sam_h_hq_adapter_weights_path: Path,
hq_mask_only: bool,
reference_sam_h_predictor: FacebookSAMPredictorHQ,
tennis: Image.Image,
one_prompt: SAMPrompt,
) -> None:
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
adapter.hq_mask_only = hq_mask_only
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
@ -268,13 +253,13 @@ def test_predictor(
@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor_equal(
sam_h: SegmentAnythingH,
hq_adapter_weights: Path,
sam_h_hq_adapter_weights_path: Path,
hq_mask_only: bool,
reference_sam_h_predictor: FacebookSAMPredictorHQ,
tennis: Image.Image,
one_prompt: SAMPrompt,
) -> None:
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
adapter.hq_mask_only = hq_mask_only
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
@ -318,8 +303,8 @@ def test_predictor_equal(
@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
batch_size = 5
@ -348,8 +333,10 @@ def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -
assert torch.equal(mask_prediction[0], mask_prediction[1])
def test_hq_sam_load_save_weights(sam_h: SegmentAnythingH, hq_adapter_weights: Path, test_device: torch.device) -> None:
weights = load_from_safetensors(hq_adapter_weights, device=test_device)
def test_hq_sam_load_save_weights(
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, test_device: torch.device
) -> None:
weights = load_from_safetensors(sam_h_hq_adapter_weights_path, device=test_device)
hq_sam_adapter = HQSAMAdapter(sam_h)
out_weights_init = hq_sam_adapter.weights

View file

@ -1,7 +1,6 @@
from math import isclose
from pathlib import Path
from typing import cast
from warnings import warn
import numpy as np
import pytest
@ -17,8 +16,8 @@ from tests.foundationals.segment_anything.utils import (
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.conversion.model_converter import ModelConverter
from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
@ -49,20 +48,11 @@ def one_prompt() -> SAMPrompt:
@pytest.fixture(scope="module")
def facebook_sam_h_weights(test_weights_path: Path) -> Path:
sam_h_weights = test_weights_path / "sam_vit_h_4b8939.pth"
if not sam_h_weights.is_file():
warn(f"could not find weights at {sam_h_weights}, skipping")
pytest.skip(allow_module_level=True)
return sam_h_weights
@pytest.fixture(scope="module")
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
def facebook_sam_h(sam_h_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
from segment_anything import build_sam_vit_h # type: ignore
sam_h = cast(FacebookSAM, build_sam_vit_h())
sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
sam_h.load_state_dict(state_dict=load_tensors(sam_h_unconverted_weights_path))
return sam_h.to(device=test_device)
@ -76,16 +66,16 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto
@pytest.fixture(scope="module")
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
sam_h = SegmentAnythingH(device=test_device)
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
return sam_h
@pytest.fixture(scope="module")
def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
def sam_h_single_output(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
return sam_h
@ -469,7 +459,10 @@ def test_predictor_resized_single_output(
def test_mask_encoder(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
facebook_sam_h_predictor: FacebookSAMPredictor,
sam_h: SegmentAnythingH,
truck: Image.Image,
one_prompt: SAMPrompt,
) -> None:
predictor = facebook_sam_h_predictor
predictor.set_image(np.array(truck))

View file

@ -1,5 +1,4 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
@ -25,16 +24,11 @@ class CifarDataset(Dataset[torch.Tensor]):
@pytest.fixture(scope="module")
def dinov2_l(
test_weights_path: Path,
dinov2_large_weights_path: Path,
test_device: torch.device,
) -> dinov2.DINOv2_large:
weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
model = dinov2.DINOv2_large(device=test_device)
model.load_from_safetensors(weights)
model.load_from_safetensors(dinov2_large_weights_path)
return model

View file

@ -24,11 +24,20 @@ def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int
class T5TextEmbedder(nn.Module):
def __init__(
self, pretrained_path: Path = Path("tests/weights/QQGYLab/T5XLFP16"), max_length: int | None = None
self,
pretrained_path: Path | str,
max_length: int | None = None,
local_files_only: bool = False,
) -> None:
super().__init__() # type: ignore[reportUnknownMemberType]
self.model: nn.Module = T5EncoderModel.from_pretrained(pretrained_path, local_files_only=True) # type: ignore
self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained(pretrained_path, local_files_only=True) # type: ignore
self.model: nn.Module = T5EncoderModel.from_pretrained( # type: ignore
pretrained_path,
local_files_only=local_files_only,
)
self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained( # type: ignore
pretrained_path,
local_files_only=local_files_only,
)
self.max_length = max_length
def forward(