mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
update tests to use new fixtures
This commit is contained in:
parent
94eeb1afc3
commit
316fe6e4f0
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -16,14 +15,8 @@ def manager() -> SDLoraManager:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
|
def weights(lora_pokemon_weights_path: Path) -> dict[str, torch.Tensor]:
|
||||||
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
return load_tensors(lora_pokemon_weights_path)
|
||||||
|
|
||||||
if not weights_path.is_file():
|
|
||||||
warn(f"could not find weights at {weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
return load_tensors(weights_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
|
|
@ -9,6 +9,8 @@ import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tests.utils import T5TextEmbedder, ensure_similar_images
|
from tests.utils import T5TextEmbedder, ensure_similar_images
|
||||||
|
|
||||||
|
from refiners.conversion import controllora_sdxl
|
||||||
|
from refiners.conversion.utils import Hub
|
||||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
|
@ -45,6 +47,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler i
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
||||||
|
|
||||||
|
from ..weight_paths import get_path
|
||||||
|
|
||||||
|
|
||||||
def _img_open(path: Path) -> Image.Image:
|
def _img_open(path: Path) -> Image.Image:
|
||||||
return Image.open(path) # type: ignore
|
return Image.open(path) # type: ignore
|
||||||
|
@ -194,69 +198,85 @@ def expected_style_aligned(ref_path: Path) -> Image.Image:
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||||
def controlnet_data(
|
def controlnet_data(
|
||||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
ref_path: Path,
|
||||||
|
controlnet_depth_weights_path: Path,
|
||||||
|
controlnet_canny_weights_path: Path,
|
||||||
|
controlnet_lineart_weights_path: Path,
|
||||||
|
controlnet_normals_weights_path: Path,
|
||||||
|
controlnet_sam_weights_path: Path,
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
||||||
cn_name: str = request.param
|
cn_name: str = request.param
|
||||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
weights_fn = {
|
|
||||||
"depth": "lllyasviel_control_v11f1p_sd15_depth",
|
|
||||||
"canny": "lllyasviel_control_v11p_sd15_canny",
|
|
||||||
"lineart": "lllyasviel_control_v11p_sd15_lineart",
|
|
||||||
"normals": "lllyasviel_control_v11p_sd15_normalbae",
|
|
||||||
"sam": "mfidabel_controlnet-segment-anything",
|
|
||||||
}
|
|
||||||
|
|
||||||
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
|
weights_fn = {
|
||||||
yield (cn_name, condition_image, expected_image, weights_path)
|
"depth": controlnet_depth_weights_path,
|
||||||
|
"canny": controlnet_canny_weights_path,
|
||||||
|
"lineart": controlnet_lineart_weights_path,
|
||||||
|
"normals": controlnet_normals_weights_path,
|
||||||
|
"sam": controlnet_sam_weights_path,
|
||||||
|
}
|
||||||
|
weights_path = weights_fn[cn_name]
|
||||||
|
|
||||||
|
yield cn_name, condition_image, expected_image, weights_path
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["canny"])
|
@pytest.fixture(scope="module", params=["canny"])
|
||||||
def controlnet_data_scale_decay(
|
def controlnet_data_scale_decay(
|
||||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
ref_path: Path,
|
||||||
|
controlnet_canny_weights_path: Path,
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
||||||
cn_name: str = request.param
|
cn_name: str = request.param
|
||||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
|
||||||
weights_fn = {
|
|
||||||
"canny": "lllyasviel_control_v11p_sd15_canny",
|
|
||||||
}
|
|
||||||
|
|
||||||
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
|
weights_fn = {
|
||||||
|
"canny": controlnet_canny_weights_path,
|
||||||
|
}
|
||||||
|
weights_path = weights_fn[cn_name]
|
||||||
|
|
||||||
yield (cn_name, condition_image, expected_image, weights_path)
|
yield (cn_name, condition_image, expected_image, weights_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def controlnet_data_tile(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Image.Image, Path]:
|
def controlnet_data_tile(
|
||||||
|
ref_path: Path,
|
||||||
|
controlnet_tiles_weights_path: Path,
|
||||||
|
) -> tuple[Image.Image, Image.Image, Path]:
|
||||||
condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore
|
condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore
|
||||||
expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
|
return condition_image, expected_image, controlnet_tiles_weights_path
|
||||||
return condition_image, expected_image, weights_path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def controlnet_data_canny(
|
||||||
|
ref_path: Path,
|
||||||
|
controlnet_canny_weights_path: Path,
|
||||||
|
) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
cn_name = "canny"
|
cn_name = "canny"
|
||||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
|
return cn_name, condition_image, expected_image, controlnet_canny_weights_path
|
||||||
return cn_name, condition_image, expected_image, weights_path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def controlnet_data_depth(
|
||||||
|
ref_path: Path,
|
||||||
|
controlnet_depth_weights_path: Path,
|
||||||
|
) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
cn_name = "depth"
|
cn_name = "depth"
|
||||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
|
return cn_name, condition_image, expected_image, controlnet_depth_weights_path
|
||||||
return cn_name, condition_image, expected_image, weights_path
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ControlLoraConfig:
|
class ControlLoraConfig:
|
||||||
scale: float
|
scale: float
|
||||||
condition_path: str
|
condition_path: str
|
||||||
weights_path: str
|
weights: Hub
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -271,38 +291,38 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
|
||||||
"PyraCanny": ControlLoraConfig(
|
"PyraCanny": ControlLoraConfig(
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
condition_path="cutecat_guide_PyraCanny.png",
|
condition_path="cutecat_guide_PyraCanny.png",
|
||||||
weights_path="refiners_control-lora-canny-rank128.safetensors",
|
weights=controllora_sdxl.canny.converted,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"expected_controllora_CPDS.png": {
|
"expected_controllora_CPDS.png": {
|
||||||
"CPDS": ControlLoraConfig(
|
"CPDS": ControlLoraConfig(
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
condition_path="cutecat_guide_CPDS.png",
|
condition_path="cutecat_guide_CPDS.png",
|
||||||
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
|
weights=controllora_sdxl.cpds.converted,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"expected_controllora_PyraCanny+CPDS.png": {
|
"expected_controllora_PyraCanny+CPDS.png": {
|
||||||
"PyraCanny": ControlLoraConfig(
|
"PyraCanny": ControlLoraConfig(
|
||||||
scale=0.55,
|
scale=0.55,
|
||||||
condition_path="cutecat_guide_PyraCanny.png",
|
condition_path="cutecat_guide_PyraCanny.png",
|
||||||
weights_path="refiners_control-lora-canny-rank128.safetensors",
|
weights=controllora_sdxl.canny.converted,
|
||||||
),
|
),
|
||||||
"CPDS": ControlLoraConfig(
|
"CPDS": ControlLoraConfig(
|
||||||
scale=0.55,
|
scale=0.55,
|
||||||
condition_path="cutecat_guide_CPDS.png",
|
condition_path="cutecat_guide_CPDS.png",
|
||||||
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
|
weights=controllora_sdxl.cpds.converted,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"expected_controllora_disabled.png": {
|
"expected_controllora_disabled.png": {
|
||||||
"PyraCanny": ControlLoraConfig(
|
"PyraCanny": ControlLoraConfig(
|
||||||
scale=0.0,
|
scale=0.0,
|
||||||
condition_path="cutecat_guide_PyraCanny.png",
|
condition_path="cutecat_guide_PyraCanny.png",
|
||||||
weights_path="refiners_control-lora-canny-rank128.safetensors",
|
weights=controllora_sdxl.canny.converted,
|
||||||
),
|
),
|
||||||
"CPDS": ControlLoraConfig(
|
"CPDS": ControlLoraConfig(
|
||||||
scale=0.0,
|
scale=0.0,
|
||||||
condition_path="cutecat_guide_CPDS.png",
|
condition_path="cutecat_guide_CPDS.png",
|
||||||
weights_path="refiners_fooocus_xl_cpds_128.safetensors",
|
weights=controllora_sdxl.cpds.converted,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -311,8 +331,8 @@ CONTROL_LORA_CONFIGS: dict[str, dict[str, ControlLoraConfig]] = {
|
||||||
@pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
|
@pytest.fixture(params=CONTROL_LORA_CONFIGS.items())
|
||||||
def controllora_sdxl_config(
|
def controllora_sdxl_config(
|
||||||
request: pytest.FixtureRequest,
|
request: pytest.FixtureRequest,
|
||||||
|
use_local_weights: bool,
|
||||||
ref_path: Path,
|
ref_path: Path,
|
||||||
test_weights_path: Path,
|
|
||||||
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
|
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
|
||||||
name: str = request.param[0]
|
name: str = request.param[0]
|
||||||
configs: dict[str, ControlLoraConfig] = request.param[1]
|
configs: dict[str, ControlLoraConfig] = request.param[1]
|
||||||
|
@ -322,7 +342,7 @@ def controllora_sdxl_config(
|
||||||
config_name: ControlLoraResolvedConfig(
|
config_name: ControlLoraResolvedConfig(
|
||||||
scale=config.scale,
|
scale=config.scale,
|
||||||
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
|
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
|
||||||
weights_path=test_weights_path / "control-loras" / config.weights_path,
|
weights_path=get_path(config.weights, use_local_weights),
|
||||||
)
|
)
|
||||||
for config_name, config in configs.items()
|
for config_name, config in configs.items()
|
||||||
}
|
}
|
||||||
|
@ -331,66 +351,57 @@ def controllora_sdxl_config(
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def t2i_adapter_data_depth(
|
||||||
|
ref_path: Path,
|
||||||
|
t2i_depth_weights_path: Path,
|
||||||
|
) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
name = "depth"
|
name = "depth"
|
||||||
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
|
||||||
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
|
return name, condition_image, expected_image, t2i_depth_weights_path
|
||||||
return name, condition_image, expected_image, weights_path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def t2i_adapter_xl_data_canny(
|
||||||
|
ref_path: Path,
|
||||||
|
t2i_sdxl_canny_weights_path: Path,
|
||||||
|
) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
name = "canny"
|
name = "canny"
|
||||||
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
||||||
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
|
return name, condition_image, expected_image, t2i_sdxl_canny_weights_path
|
||||||
|
|
||||||
if not weights_path.is_file():
|
|
||||||
warn(f"could not find weights at {weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
return name, condition_image, expected_image, weights_path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
def lora_data_pokemon(
|
||||||
|
ref_path: Path,
|
||||||
|
lora_pokemon_weights_path: Path,
|
||||||
|
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||||
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
tensors = load_tensors(lora_pokemon_weights_path)
|
||||||
|
|
||||||
if not weights_path.is_file():
|
|
||||||
warn(f"could not find weights at {weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
tensors = load_tensors(weights_path)
|
|
||||||
return expected_image, tensors
|
return expected_image, tensors
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
def lora_data_dpo(
|
||||||
|
ref_path: Path,
|
||||||
|
lora_dpo_weights_path: Path,
|
||||||
|
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||||
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
|
tensors = load_from_safetensors(lora_dpo_weights_path)
|
||||||
|
|
||||||
if not weights_path.is_file():
|
|
||||||
warn(f"could not find weights at {weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
tensors = load_from_safetensors(weights_path)
|
|
||||||
return expected_image, tensors
|
return expected_image, tensors
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
|
def lora_sliders(
|
||||||
weights_path = test_weights_path / "loras" / "sliders"
|
lora_slider_age_weights_path: Path,
|
||||||
|
lora_slider_cartoon_style_weights_path: Path,
|
||||||
if not weights_path.is_dir():
|
lora_slider_eyesize_weights_path: Path,
|
||||||
warn(f"could not find weights at {weights_path}, skipping")
|
) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]:
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"age": load_tensors(weights_path / "age.pt"), # type: ignore
|
"age": load_tensors(lora_slider_age_weights_path),
|
||||||
"cartoon_style": load_tensors(weights_path / "cartoon_style.pt"), # type: ignore
|
"cartoon_style": load_tensors(lora_slider_cartoon_style_weights_path),
|
||||||
"eyesize": load_tensors(weights_path / "eyesize.pt"), # type: ignore
|
"eyesize": load_tensors(lora_slider_eyesize_weights_path),
|
||||||
}, {
|
}, {
|
||||||
"age": 0.3,
|
"age": 0.3,
|
||||||
"cartoon_style": -0.2,
|
"cartoon_style": -0.2,
|
||||||
|
@ -477,122 +488,12 @@ def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch
|
||||||
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
|
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def text_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
|
||||||
if not text_encoder_weights.is_file():
|
|
||||||
warn(f"could not find weights at {text_encoder_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return text_encoder_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def lda_weights(test_weights_path: Path) -> Path:
|
|
||||||
lda_weights = test_weights_path / "lda.safetensors"
|
|
||||||
if not lda_weights.is_file():
|
|
||||||
warn(f"could not find weights at {lda_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return lda_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def unet_weights_std(test_weights_path: Path) -> Path:
|
|
||||||
unet_weights_std = test_weights_path / "unet.safetensors"
|
|
||||||
if not unet_weights_std.is_file():
|
|
||||||
warn(f"could not find weights at {unet_weights_std}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return unet_weights_std
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def unet_weights_inpainting(test_weights_path: Path) -> Path:
|
|
||||||
unet_weights_inpainting = test_weights_path / "inpainting" / "unet.safetensors"
|
|
||||||
if not unet_weights_inpainting.is_file():
|
|
||||||
warn(f"could not find weights at {unet_weights_inpainting}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return unet_weights_inpainting
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def lda_ft_mse_weights(test_weights_path: Path) -> Path:
|
|
||||||
lda_weights = test_weights_path / "lda_ft_mse.safetensors"
|
|
||||||
if not lda_weights.is_file():
|
|
||||||
warn(f"could not find weights at {lda_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return lda_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def ella_weights(test_weights_path: Path) -> tuple[Path, Path]:
|
|
||||||
ella_adapter_weights = test_weights_path / "ELLA-Adapter" / "ella-sd1.5-tsc-t5xl.safetensors"
|
|
||||||
if not ella_adapter_weights.is_file():
|
|
||||||
warn(f"could not find weights at {ella_adapter_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
t5xl_weights = test_weights_path / "QQGYLab" / "T5XLFP16"
|
|
||||||
t5xl_files = [
|
|
||||||
"config.json",
|
|
||||||
"model.safetensors",
|
|
||||||
"special_tokens_map.json",
|
|
||||||
"spiece.model",
|
|
||||||
"tokenizer_config.json",
|
|
||||||
"tokenizer.json",
|
|
||||||
]
|
|
||||||
for file in t5xl_files:
|
|
||||||
if not (t5xl_weights / file).is_file():
|
|
||||||
warn(f"could not find weights at {t5xl_weights / file}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
return (ella_adapter_weights, t5xl_weights)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def ip_adapter_weights(test_weights_path: Path) -> Path:
|
|
||||||
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
|
|
||||||
if not ip_adapter_weights.is_file():
|
|
||||||
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return ip_adapter_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def ip_adapter_plus_weights(test_weights_path: Path) -> Path:
|
|
||||||
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sd15.safetensors"
|
|
||||||
if not ip_adapter_weights.is_file():
|
|
||||||
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return ip_adapter_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sdxl_ip_adapter_weights(test_weights_path: Path) -> Path:
|
|
||||||
ip_adapter_weights = test_weights_path / "ip-adapter_sdxl_vit-h.safetensors"
|
|
||||||
if not ip_adapter_weights.is_file():
|
|
||||||
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return ip_adapter_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
|
|
||||||
ip_adapter_weights = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
|
|
||||||
if not ip_adapter_weights.is_file():
|
|
||||||
warn(f"could not find weights at {ip_adapter_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return ip_adapter_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def image_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
|
||||||
if not image_encoder_weights.is_file():
|
|
||||||
warn(f"could not find weights at {image_encoder_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return image_encoder_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_std(
|
def sd15_std(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -600,16 +501,19 @@ def sd15_std(
|
||||||
|
|
||||||
sd15 = StableDiffusion_1(device=test_device)
|
sd15 = StableDiffusion_1(device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_std_sde(
|
def sd15_std_sde(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -618,16 +522,19 @@ def sd15_std_sde(
|
||||||
sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
|
sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
|
||||||
sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
|
sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_std_float16(
|
def sd15_std_float16(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -635,18 +542,18 @@ def sd15_std_float16(
|
||||||
|
|
||||||
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
|
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_std_bfloat16(
|
def sd15_std_bfloat16(
|
||||||
text_encoder_weights: Path,
|
sd15_text_encoder_weights_path: Path,
|
||||||
lda_weights: Path,
|
sd15_autoencoder_weights_path: Path,
|
||||||
unet_weights_std: Path,
|
sd15_unet_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
|
@ -655,16 +562,19 @@ def sd15_std_bfloat16(
|
||||||
|
|
||||||
sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
|
sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_inpainting(
|
def sd15_inpainting(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_inpainting_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1_Inpainting:
|
) -> StableDiffusion_1_Inpainting:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -673,16 +583,19 @@ def sd15_inpainting(
|
||||||
unet = SD1UNet(in_channels=9)
|
unet = SD1UNet(in_channels=9)
|
||||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_inpainting)
|
sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_inpainting_float16(
|
def sd15_inpainting_float16(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_inpainting_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1_Inpainting:
|
) -> StableDiffusion_1_Inpainting:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -691,16 +604,19 @@ def sd15_inpainting_float16(
|
||||||
unet = SD1UNet(in_channels=9)
|
unet = SD1UNet(in_channels=9)
|
||||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_inpainting)
|
sd15.unet.load_from_safetensors(sd15_unet_inpainting_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_ddim(
|
def sd15_ddim(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -709,16 +625,19 @@ def sd15_ddim(
|
||||||
ddim_solver = DDIM(num_inference_steps=20)
|
ddim_solver = DDIM(num_inference_steps=20)
|
||||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_ddim_karras(
|
def sd15_ddim_karras(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -727,16 +646,18 @@ def sd15_ddim_karras(
|
||||||
ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
|
ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
|
||||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_euler(
|
def sd15_euler(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_weights_path: Path,
|
||||||
|
sd15_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -745,16 +666,19 @@ def sd15_euler(
|
||||||
euler_solver = Euler(num_inference_steps=30)
|
euler_solver = Euler(num_inference_steps=30)
|
||||||
sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
|
sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_from_safetensors(lda_weights)
|
sd15.lda.load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_ddim_lda_ft_mse(
|
def sd15_ddim_lda_ft_mse(
|
||||||
text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
|
sd15_text_encoder_weights_path: Path,
|
||||||
|
sd15_autoencoder_mse_weights_path: Path,
|
||||||
|
sd15_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_1:
|
) -> StableDiffusion_1:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
@ -763,52 +687,19 @@ def sd15_ddim_lda_ft_mse(
|
||||||
ddim_solver = DDIM(num_inference_steps=20)
|
ddim_solver = DDIM(num_inference_steps=20)
|
||||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
|
sd15.lda.load_from_safetensors(sd15_autoencoder_mse_weights_path)
|
||||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
sd15.unet.load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lda_weights(test_weights_path: Path) -> Path:
|
|
||||||
sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
|
|
||||||
if not sdxl_lda_weights.is_file():
|
|
||||||
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return sdxl_lda_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
|
|
||||||
sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
|
||||||
if not sdxl_lda_weights.is_file():
|
|
||||||
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return sdxl_lda_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_unet_weights(test_weights_path: Path) -> Path:
|
|
||||||
sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
|
|
||||||
if not sdxl_unet_weights.is_file():
|
|
||||||
warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return sdxl_unet_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
|
|
||||||
if not sdxl_double_text_encoder_weights.is_file():
|
|
||||||
warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return sdxl_double_text_encoder_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sdxl_ddim(
|
def sdxl_ddim(
|
||||||
sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
|
sdxl_text_encoder_weights_path: Path,
|
||||||
|
sdxl_autoencoder_weights_path: Path,
|
||||||
|
sdxl_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_XL:
|
) -> StableDiffusion_XL:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn(message="not running on CPU, skipping")
|
warn(message="not running on CPU, skipping")
|
||||||
|
@ -817,16 +708,19 @@ def sdxl_ddim(
|
||||||
solver = DDIM(num_inference_steps=30)
|
solver = DDIM(num_inference_steps=30)
|
||||||
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
|
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
|
sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
|
||||||
|
|
||||||
return sdxl
|
return sdxl
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sdxl_ddim_lda_fp16_fix(
|
def sdxl_ddim_lda_fp16_fix(
|
||||||
sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
|
sdxl_text_encoder_weights_path: Path,
|
||||||
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
|
sdxl_unet_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_XL:
|
) -> StableDiffusion_XL:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn(message="not running on CPU, skipping")
|
warn(message="not running on CPU, skipping")
|
||||||
|
@ -835,9 +729,9 @@ def sdxl_ddim_lda_fp16_fix(
|
||||||
solver = DDIM(num_inference_steps=30)
|
solver = DDIM(num_inference_steps=30)
|
||||||
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
|
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
|
||||||
|
|
||||||
return sdxl
|
return sdxl
|
||||||
|
|
||||||
|
@ -856,23 +750,18 @@ def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_X
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def multi_upscaler(
|
def multi_upscaler(
|
||||||
test_weights_path: Path,
|
controlnet_tiles_weights_path: Path,
|
||||||
unet_weights_std: Path,
|
sd15_text_encoder_weights_path: Path,
|
||||||
text_encoder_weights: Path,
|
sd15_autoencoder_mse_weights_path: Path,
|
||||||
lda_ft_mse_weights: Path,
|
sd15_unet_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> MultiUpscaler:
|
) -> MultiUpscaler:
|
||||||
controlnet_tile_weights = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
|
|
||||||
if not controlnet_tile_weights.is_file():
|
|
||||||
warn(message=f"could not find weights at {controlnet_tile_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
return MultiUpscaler(
|
return MultiUpscaler(
|
||||||
checkpoints=UpscalerCheckpoints(
|
checkpoints=UpscalerCheckpoints(
|
||||||
unet=unet_weights_std,
|
unet=sd15_unet_weights_path,
|
||||||
clip_text_encoder=text_encoder_weights,
|
clip_text_encoder=sd15_text_encoder_weights_path,
|
||||||
lda=lda_ft_mse_weights,
|
lda=sd15_autoencoder_mse_weights_path,
|
||||||
controlnet_tile=controlnet_tile_weights,
|
controlnet_tile=controlnet_tiles_weights_path,
|
||||||
),
|
),
|
||||||
device=test_device,
|
device=test_device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
|
@ -891,7 +780,9 @@ def expected_multi_upscaler(ref_path: Path) -> Image.Image:
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_std_random_init(
|
def test_diffusion_std_random_init(
|
||||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
sd15_std: StableDiffusion_1,
|
||||||
|
expected_image_std_random_init: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
):
|
):
|
||||||
sd15 = sd15_std
|
sd15 = sd15_std
|
||||||
|
|
||||||
|
@ -1553,6 +1444,9 @@ def test_diffusion_sdxl_control_lora(
|
||||||
|
|
||||||
adapters: dict[str, ControlLoraAdapter] = {}
|
adapters: dict[str, ControlLoraAdapter] = {}
|
||||||
for config_name, config in configs.items():
|
for config_name, config in configs.items():
|
||||||
|
if not config.weights_path.is_file():
|
||||||
|
pytest.skip(f"File not found: {config.weights_path}")
|
||||||
|
|
||||||
adapter = ControlLoraAdapter(
|
adapter = ControlLoraAdapter(
|
||||||
name=config_name,
|
name=config_name,
|
||||||
scale=config.scale,
|
scale=config.scale,
|
||||||
|
@ -1922,13 +1816,18 @@ def test_diffusion_textual_inversion_random_init(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_ella_adapter(
|
def test_diffusion_ella_adapter(
|
||||||
sd15_std_float16: StableDiffusion_1,
|
sd15_std_float16: StableDiffusion_1,
|
||||||
ella_weights: tuple[Path, Path],
|
ella_sd15_tsc_t5xl_weights_path: Path,
|
||||||
|
t5xl_transformers_path: str,
|
||||||
expected_image_ella_adapter: Image.Image,
|
expected_image_ella_adapter: Image.Image,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
|
use_local_weights: bool,
|
||||||
):
|
):
|
||||||
sd15 = sd15_std_float16
|
sd15 = sd15_std_float16
|
||||||
ella_adapter_weights, t5xl_weights = ella_weights
|
t5_encoder = T5TextEmbedder(
|
||||||
t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16)
|
pretrained_path=t5xl_transformers_path,
|
||||||
|
local_files_only=use_local_weights,
|
||||||
|
max_length=128,
|
||||||
|
).to(test_device, torch.float16)
|
||||||
|
|
||||||
prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
|
prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
|
||||||
negative_prompt = ""
|
negative_prompt = ""
|
||||||
|
@ -1938,7 +1837,7 @@ def test_diffusion_ella_adapter(
|
||||||
llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt)
|
llm_text_embedding, negative_prompt_embeds = t5_encoder(prompt), t5_encoder(negative_prompt)
|
||||||
prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)
|
prompt_embedding = torch.cat((negative_prompt_embeds, llm_text_embedding)).to(test_device, torch.float16)
|
||||||
|
|
||||||
adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_adapter_weights))
|
adapter = SD1ELLAAdapter(target=sd15.unet, weights=load_from_safetensors(ella_sd15_tsc_t5xl_weights_path))
|
||||||
adapter.inject()
|
adapter.inject()
|
||||||
sd15.set_inference_steps(50)
|
sd15.set_inference_steps(50)
|
||||||
manual_seed(1001)
|
manual_seed(1001)
|
||||||
|
@ -1959,8 +1858,8 @@ def test_diffusion_ella_adapter(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_ip_adapter(
|
def test_diffusion_ip_adapter(
|
||||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||||
ip_adapter_weights: Path,
|
ip_adapter_sd15_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
woman_image: Image.Image,
|
woman_image: Image.Image,
|
||||||
expected_image_ip_adapter_woman: Image.Image,
|
expected_image_ip_adapter_woman: Image.Image,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
|
@ -1976,8 +1875,8 @@ def test_diffusion_ip_adapter(
|
||||||
prompt = "best quality, high quality"
|
prompt = "best quality, high quality"
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
|
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
@ -2004,8 +1903,8 @@ def test_diffusion_ip_adapter(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_ip_adapter_multi(
|
def test_diffusion_ip_adapter_multi(
|
||||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||||
ip_adapter_weights: Path,
|
ip_adapter_sd15_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
woman_image: Image.Image,
|
woman_image: Image.Image,
|
||||||
statue_image: Image.Image,
|
statue_image: Image.Image,
|
||||||
expected_image_ip_adapter_multi: Image.Image,
|
expected_image_ip_adapter_multi: Image.Image,
|
||||||
|
@ -2016,8 +1915,8 @@ def test_diffusion_ip_adapter_multi(
|
||||||
prompt = "best quality, high quality"
|
prompt = "best quality, high quality"
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
|
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
@ -2044,8 +1943,8 @@ def test_diffusion_ip_adapter_multi(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_sdxl_ip_adapter(
|
def test_diffusion_sdxl_ip_adapter(
|
||||||
sdxl_ddim: StableDiffusion_XL,
|
sdxl_ddim: StableDiffusion_XL,
|
||||||
sdxl_ip_adapter_weights: Path,
|
ip_adapter_sdxl_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
woman_image: Image.Image,
|
woman_image: Image.Image,
|
||||||
expected_image_sdxl_ip_adapter_woman: Image.Image,
|
expected_image_sdxl_ip_adapter_woman: Image.Image,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
|
@ -2055,8 +1954,8 @@ def test_diffusion_sdxl_ip_adapter(
|
||||||
prompt = "best quality, high quality"
|
prompt = "best quality, high quality"
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
|
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
with no_grad():
|
with no_grad():
|
||||||
|
@ -2093,8 +1992,8 @@ def test_diffusion_sdxl_ip_adapter(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_ip_adapter_controlnet(
|
def test_diffusion_ip_adapter_controlnet(
|
||||||
sd15_ddim: StableDiffusion_1,
|
sd15_ddim: StableDiffusion_1,
|
||||||
ip_adapter_weights: Path,
|
ip_adapter_sd15_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
lora_data_pokemon: tuple[Image.Image, Path],
|
lora_data_pokemon: tuple[Image.Image, Path],
|
||||||
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
|
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
|
||||||
expected_image_ip_adapter_controlnet: Image.Image,
|
expected_image_ip_adapter_controlnet: Image.Image,
|
||||||
|
@ -2107,8 +2006,8 @@ def test_diffusion_ip_adapter_controlnet(
|
||||||
prompt = "best quality, high quality"
|
prompt = "best quality, high quality"
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_weights))
|
ip_adapter = SD1IPAdapter(target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_weights_path))
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
depth_controlnet = SD1ControlnetAdapter(
|
depth_controlnet = SD1ControlnetAdapter(
|
||||||
|
@ -2149,8 +2048,8 @@ def test_diffusion_ip_adapter_controlnet(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_ip_adapter_plus(
|
def test_diffusion_ip_adapter_plus(
|
||||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||||
ip_adapter_plus_weights: Path,
|
ip_adapter_sd15_plus_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
statue_image: Image.Image,
|
statue_image: Image.Image,
|
||||||
expected_image_ip_adapter_plus_statue: Image.Image,
|
expected_image_ip_adapter_plus_statue: Image.Image,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
|
@ -2161,9 +2060,9 @@ def test_diffusion_ip_adapter_plus(
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
ip_adapter = SD1IPAdapter(
|
ip_adapter = SD1IPAdapter(
|
||||||
target=sd15.unet, weights=load_from_safetensors(ip_adapter_plus_weights), fine_grained=True
|
target=sd15.unet, weights=load_from_safetensors(ip_adapter_sd15_plus_weights_path), fine_grained=True
|
||||||
)
|
)
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
@ -2190,8 +2089,8 @@ def test_diffusion_ip_adapter_plus(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_sdxl_ip_adapter_plus(
|
def test_diffusion_sdxl_ip_adapter_plus(
|
||||||
sdxl_ddim: StableDiffusion_XL,
|
sdxl_ddim: StableDiffusion_XL,
|
||||||
sdxl_ip_adapter_plus_weights: Path,
|
ip_adapter_sdxl_plus_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
woman_image: Image.Image,
|
woman_image: Image.Image,
|
||||||
expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
|
expected_image_sdxl_ip_adapter_plus_woman: Image.Image,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
|
@ -2202,9 +2101,9 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
ip_adapter = SDXLIPAdapter(
|
ip_adapter = SDXLIPAdapter(
|
||||||
target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_plus_weights), fine_grained=True
|
target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path), fine_grained=True
|
||||||
)
|
)
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
|
@ -2608,8 +2507,8 @@ def test_freeu(
|
||||||
def test_hello_world(
|
def test_hello_world(
|
||||||
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
|
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
|
||||||
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||||
sdxl_ip_adapter_weights: Path,
|
ip_adapter_sdxl_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
|
hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
|
||||||
) -> None:
|
) -> None:
|
||||||
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
|
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
|
||||||
|
@ -2622,8 +2521,8 @@ def test_hello_world(
|
||||||
warn(f"could not find weights at {weights_path}, skipping")
|
warn(f"could not find weights at {weights_path}, skipping")
|
||||||
pytest.skip(allow_module_level=True)
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
|
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(ip_adapter_sdxl_weights_path))
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
|
image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
|
||||||
|
@ -2743,24 +2642,19 @@ def expected_ic_light(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_ic_light.png").convert("RGB")
|
return _img_open(ref_path / "expected_ic_light.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def ic_light_sd15_fc_weights(test_weights_path: Path) -> Path:
|
|
||||||
return test_weights_path / "iclight_sd15_fc-refiners.safetensors"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ic_light_sd15_fc(
|
def ic_light_sd15_fc(
|
||||||
ic_light_sd15_fc_weights: Path,
|
ic_light_sd15_fc_weights_path: Path,
|
||||||
unet_weights_std: Path,
|
sd15_unet_weights_path: Path,
|
||||||
lda_weights: Path,
|
sd15_autoencoder_weights_path: Path,
|
||||||
text_encoder_weights: Path,
|
sd15_text_encoder_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> ICLight:
|
) -> ICLight:
|
||||||
return ICLight(
|
return ICLight(
|
||||||
patch_weights=load_from_safetensors(ic_light_sd15_fc_weights),
|
patch_weights=load_from_safetensors(ic_light_sd15_fc_weights_path),
|
||||||
unet=SD1UNet(in_channels=4).load_from_safetensors(unet_weights_std),
|
unet=SD1UNet(in_channels=4).load_from_safetensors(sd15_unet_weights_path),
|
||||||
lda=SD1Autoencoder().load_from_safetensors(lda_weights),
|
lda=SD1Autoencoder().load_from_safetensors(sd15_autoencoder_weights_path),
|
||||||
clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(text_encoder_weights),
|
clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(sd15_text_encoder_weights_path),
|
||||||
device=test_device,
|
device=test_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -29,74 +29,11 @@ def ref_path(test_e2e_path: Path) -> Path:
|
||||||
return test_e2e_path / "test_doc_examples_ref"
|
return test_e2e_path / "test_doc_examples_ref"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
path = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
|
|
||||||
if not path.is_file():
|
|
||||||
warn(message=f"could not find weights at {path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
|
|
||||||
path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
|
||||||
if not path.is_file():
|
|
||||||
warn(message=f"could not find weights at {path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sdxl_unet_weights(test_weights_path: Path) -> Path:
|
|
||||||
path = test_weights_path / "sdxl-unet.safetensors"
|
|
||||||
if not path.is_file():
|
|
||||||
warn(message=f"could not find weights at {path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
|
|
||||||
path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
|
|
||||||
if not path.is_file():
|
|
||||||
warn(f"could not find weights at {path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def image_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
path = test_weights_path / "CLIPImageEncoderH.safetensors"
|
|
||||||
if not path.is_file():
|
|
||||||
warn(f"could not find weights at {path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def scifi_lora_weights(test_weights_path: Path) -> Path:
|
|
||||||
path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors"
|
|
||||||
if not path.is_file():
|
|
||||||
warn(message=f"could not find weights at {path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def pixelart_lora_weights(test_weights_path: Path) -> Path:
|
|
||||||
path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors"
|
|
||||||
if not path.is_file():
|
|
||||||
warn(message=f"could not find weights at {path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sdxl(
|
def sdxl(
|
||||||
sdxl_text_encoder_weights: Path,
|
sdxl_text_encoder_weights_path: Path,
|
||||||
sdxl_lda_fp16_fix_weights: Path,
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
sdxl_unet_weights: Path,
|
sdxl_unet_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> StableDiffusion_XL:
|
) -> StableDiffusion_XL:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
|
@ -105,9 +42,9 @@ def sdxl(
|
||||||
|
|
||||||
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16)
|
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
|
||||||
|
|
||||||
return sdxl
|
return sdxl
|
||||||
|
|
||||||
|
@ -180,7 +117,7 @@ def test_guide_adapting_sdxl_vanilla(
|
||||||
def test_guide_adapting_sdxl_single_lora(
|
def test_guide_adapting_sdxl_single_lora(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl: StableDiffusion_XL,
|
sdxl: StableDiffusion_XL,
|
||||||
scifi_lora_weights: Path,
|
lora_scifi_weights_path: Path,
|
||||||
expected_image_guide_adapting_sdxl_single_lora: Image.Image,
|
expected_image_guide_adapting_sdxl_single_lora: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
|
@ -195,7 +132,7 @@ def test_guide_adapting_sdxl_single_lora(
|
||||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
|
manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
text=prompt + ", best quality, high quality",
|
text=prompt + ", best quality, high quality",
|
||||||
|
@ -222,8 +159,8 @@ def test_guide_adapting_sdxl_single_lora(
|
||||||
def test_guide_adapting_sdxl_multiple_loras(
|
def test_guide_adapting_sdxl_multiple_loras(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl: StableDiffusion_XL,
|
sdxl: StableDiffusion_XL,
|
||||||
scifi_lora_weights: Path,
|
lora_scifi_weights_path: Path,
|
||||||
pixelart_lora_weights: Path,
|
lora_pixelart_weights_path: Path,
|
||||||
expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
|
expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
|
@ -238,8 +175,8 @@ def test_guide_adapting_sdxl_multiple_loras(
|
||||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
|
manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
|
||||||
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4)
|
manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.4)
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
text=prompt + ", best quality, high quality",
|
text=prompt + ", best quality, high quality",
|
||||||
|
@ -266,10 +203,10 @@ def test_guide_adapting_sdxl_multiple_loras(
|
||||||
def test_guide_adapting_sdxl_loras_ip_adapter(
|
def test_guide_adapting_sdxl_loras_ip_adapter(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl: StableDiffusion_XL,
|
sdxl: StableDiffusion_XL,
|
||||||
sdxl_ip_adapter_plus_weights: Path,
|
ip_adapter_sdxl_plus_weights_path: Path,
|
||||||
image_encoder_weights: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
scifi_lora_weights: Path,
|
lora_scifi_weights_path: Path,
|
||||||
pixelart_lora_weights: Path,
|
lora_pixelart_weights_path: Path,
|
||||||
image_prompt_german_castle: Image.Image,
|
image_prompt_german_castle: Image.Image,
|
||||||
expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
|
expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -285,16 +222,16 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
|
||||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5)
|
manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path), scale=1.5)
|
||||||
manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55)
|
manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.55)
|
||||||
|
|
||||||
ip_adapter = SDXLIPAdapter(
|
ip_adapter = SDXLIPAdapter(
|
||||||
target=sdxl.unet,
|
target=sdxl.unet,
|
||||||
weights=load_from_safetensors(sdxl_ip_adapter_plus_weights),
|
weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path),
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
fine_grained=True,
|
fine_grained=True,
|
||||||
)
|
)
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
|
|
|
@ -26,51 +26,6 @@ def ensure_gc():
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_unet_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl-unet.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lcm_unet_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl-lcm-unet.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lcm_lora_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl-lcm-lora.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_path(test_e2e_path: Path) -> Path:
|
def ref_path(test_e2e_path: Path) -> Path:
|
||||||
return test_e2e_path / "test_lcm_ref"
|
return test_e2e_path / "test_lcm_ref"
|
||||||
|
@ -94,9 +49,9 @@ def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_lcm_base(
|
def test_lcm_base(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl_lda_fp16_fix_weights: Path,
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
sdxl_lcm_unet_weights: Path,
|
sdxl_unet_lcm_weights_path: Path,
|
||||||
sdxl_text_encoder_weights: Path,
|
sdxl_text_encoder_weights_path: Path,
|
||||||
expected_lcm_base: Image.Image,
|
expected_lcm_base: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
|
@ -111,9 +66,9 @@ def test_lcm_base(
|
||||||
# not in the diffusion loop.
|
# not in the diffusion loop.
|
||||||
SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
|
SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(sdxl_lcm_unet_weights)
|
sdxl.unet.load_from_safetensors(sdxl_unet_lcm_weights_path)
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_base
|
expected_image = expected_lcm_base
|
||||||
|
@ -141,10 +96,10 @@ def test_lcm_base(
|
||||||
@pytest.mark.parametrize("condition_scale", [1.0, 1.2])
|
@pytest.mark.parametrize("condition_scale", [1.0, 1.2])
|
||||||
def test_lcm_lora_with_guidance(
|
def test_lcm_lora_with_guidance(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl_lda_fp16_fix_weights: Path,
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
sdxl_unet_weights: Path,
|
sdxl_unet_weights_path: Path,
|
||||||
sdxl_text_encoder_weights: Path,
|
sdxl_text_encoder_weights_path: Path,
|
||||||
sdxl_lcm_lora_weights: Path,
|
lora_sdxl_lcm_weights_path: Path,
|
||||||
expected_lcm_lora_1_0: Image.Image,
|
expected_lcm_lora_1_0: Image.Image,
|
||||||
expected_lcm_lora_1_2: Image.Image,
|
expected_lcm_lora_1_2: Image.Image,
|
||||||
condition_scale: float,
|
condition_scale: float,
|
||||||
|
@ -156,12 +111,12 @@ def test_lcm_lora_with_guidance(
|
||||||
solver = LCMSolver(num_inference_steps=4)
|
solver = LCMSolver(num_inference_steps=4)
|
||||||
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
|
add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
|
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
|
||||||
|
@ -191,10 +146,10 @@ def test_lcm_lora_with_guidance(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_lcm_lora_without_guidance(
|
def test_lcm_lora_without_guidance(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl_lda_fp16_fix_weights: Path,
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
sdxl_unet_weights: Path,
|
sdxl_unet_weights_path: Path,
|
||||||
sdxl_text_encoder_weights: Path,
|
sdxl_text_encoder_weights_path: Path,
|
||||||
sdxl_lcm_lora_weights: Path,
|
lora_sdxl_lcm_weights_path: Path,
|
||||||
expected_lcm_lora_1_0: Image.Image,
|
expected_lcm_lora_1_0: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
|
@ -205,12 +160,12 @@ def test_lcm_lora_without_guidance(
|
||||||
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
||||||
sdxl.classifier_free_guidance = False
|
sdxl.classifier_free_guidance = False
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
|
add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_lora_1_0
|
expected_image = expected_lcm_lora_1_0
|
||||||
|
|
|
@ -25,60 +25,6 @@ def ensure_gc():
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_unet_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl-unet.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lightning_4step_unet_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl_lightning_4step_unet.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lightning_1step_unet_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl_lightning_1step_unet_x0.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sdxl_lightning_4step_lora_weights(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "sdxl_lightning_4step_lora.safetensors"
|
|
||||||
if not r.is_file():
|
|
||||||
warn(f"could not find weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_path(test_e2e_path: Path) -> Path:
|
def ref_path(test_e2e_path: Path) -> Path:
|
||||||
return test_e2e_path / "test_lightning_ref"
|
return test_e2e_path / "test_lightning_ref"
|
||||||
|
@ -102,16 +48,16 @@ def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_lightning_base_4step(
|
def test_lightning_base_4step(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl_lda_fp16_fix_weights: Path,
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
sdxl_lightning_4step_unet_weights: Path,
|
sdxl_unet_lightning_4step_weights_path: Path,
|
||||||
sdxl_text_encoder_weights: Path,
|
sdxl_text_encoder_weights_path: Path,
|
||||||
expected_lightning_base_4step: Image.Image,
|
expected_lightning_base_4step: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn(message="not running on CPU, skipping")
|
warn(message="not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
unet_weights = sdxl_lightning_4step_unet_weights
|
unet_weights = sdxl_unet_lightning_4step_weights_path
|
||||||
expected_image = expected_lightning_base_4step
|
expected_image = expected_lightning_base_4step
|
||||||
|
|
||||||
solver = Euler(
|
solver = Euler(
|
||||||
|
@ -125,8 +71,8 @@ def test_lightning_base_4step(
|
||||||
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
||||||
sdxl.classifier_free_guidance = False
|
sdxl.classifier_free_guidance = False
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(unet_weights)
|
sdxl.unet.load_from_safetensors(unet_weights)
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
|
@ -153,16 +99,16 @@ def test_lightning_base_4step(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_lightning_base_1step(
|
def test_lightning_base_1step(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl_lda_fp16_fix_weights: Path,
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
sdxl_lightning_1step_unet_weights: Path,
|
sdxl_unet_lightning_1step_weights_path: Path,
|
||||||
sdxl_text_encoder_weights: Path,
|
sdxl_text_encoder_weights_path: Path,
|
||||||
expected_lightning_base_1step: Image.Image,
|
expected_lightning_base_1step: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn(message="not running on CPU, skipping")
|
warn(message="not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
unet_weights = sdxl_lightning_1step_unet_weights
|
unet_weights = sdxl_unet_lightning_1step_weights_path
|
||||||
expected_image = expected_lightning_base_1step
|
expected_image = expected_lightning_base_1step
|
||||||
|
|
||||||
solver = Euler(
|
solver = Euler(
|
||||||
|
@ -176,8 +122,8 @@ def test_lightning_base_1step(
|
||||||
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
||||||
sdxl.classifier_free_guidance = False
|
sdxl.classifier_free_guidance = False
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(unet_weights)
|
sdxl.unet.load_from_safetensors(unet_weights)
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
|
@ -204,10 +150,10 @@ def test_lightning_base_1step(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_lightning_lora_4step(
|
def test_lightning_lora_4step(
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
sdxl_lda_fp16_fix_weights: Path,
|
sdxl_autoencoder_fp16fix_weights_path: Path,
|
||||||
sdxl_unet_weights: Path,
|
sdxl_unet_weights_path: Path,
|
||||||
sdxl_text_encoder_weights: Path,
|
sdxl_text_encoder_weights_path: Path,
|
||||||
sdxl_lightning_4step_lora_weights: Path,
|
lora_sdxl_lightning_4step_weights_path: Path,
|
||||||
expected_lightning_lora_4step: Image.Image,
|
expected_lightning_lora_4step: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
|
@ -227,12 +173,12 @@ def test_lightning_lora_4step(
|
||||||
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
|
||||||
sdxl.classifier_free_guidance = False
|
sdxl.classifier_free_guidance = False
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
|
||||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
|
||||||
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
add_lcm_lora(manager, load_from_safetensors(sdxl_lightning_4step_lora_weights), name="lightning")
|
add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lightning_4step_weights_path), name="lightning")
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
|
|
||||||
|
|
|
@ -29,19 +29,10 @@ def expected_cactus_mask(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_cactus_mask.png")
|
return _img_open(ref_path / "expected_cactus_mask.png")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mvanet_weights(test_weights_path: Path) -> Path:
|
|
||||||
weights = test_weights_path / "mvanet" / "mvanet.safetensors"
|
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {test_weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet:
|
def mvanet_model(mvanet_weights_path: Path, test_device: torch.device) -> MVANet:
|
||||||
model = MVANet(device=test_device).eval() # .eval() is important!
|
model = MVANet(device=test_device).eval() # .eval() is important!
|
||||||
model.load_from_safetensors(mvanet_weights)
|
model.load_from_safetensors(mvanet_weights_path)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,7 +52,7 @@ def test_mvanet(
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_mvanet_to(
|
def test_mvanet_to(
|
||||||
mvanet_weights: Path,
|
mvanet_weights_path: Path,
|
||||||
ref_cactus: Image.Image,
|
ref_cactus: Image.Image,
|
||||||
expected_cactus_mask: Image.Image,
|
expected_cactus_mask: Image.Image,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
|
@ -71,7 +62,7 @@ def test_mvanet_to(
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
model = MVANet(device=torch.device("cpu")).eval()
|
model = MVANet(device=torch.device("cpu")).eval()
|
||||||
model.load_from_safetensors(mvanet_weights)
|
model.load_from_safetensors(mvanet_weights_path)
|
||||||
model.to(test_device)
|
model.to(test_device)
|
||||||
|
|
||||||
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
|
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -29,19 +28,13 @@ def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image
|
||||||
return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
|
return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def informative_drawings_weights(test_weights_path: Path) -> Path:
|
|
||||||
weights = test_weights_path / "informative-drawings.safetensors"
|
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {test_weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
|
def informative_drawings_model(
|
||||||
|
controlnet_preprocessor_info_drawings_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> InformativeDrawings:
|
||||||
model = InformativeDrawings(device=test_device)
|
model = InformativeDrawings(device=test_device)
|
||||||
model.load_from_safetensors(informative_drawings_weights)
|
model.load_from_safetensors(controlnet_preprocessor_info_drawings_weights_path)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -38,24 +37,15 @@ def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
|
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def box_segmenter_weights(test_weights_path: Path) -> Path:
|
|
||||||
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
|
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {test_weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
|
|
||||||
def test_box_segmenter(
|
def test_box_segmenter(
|
||||||
box_segmenter_weights: Path,
|
box_segmenter_weights_path: Path,
|
||||||
ref_shelves: Image.Image,
|
ref_shelves: Image.Image,
|
||||||
expected_box_segmenter_plant_mask: Image.Image,
|
expected_box_segmenter_plant_mask: Image.Image,
|
||||||
expected_box_segmenter_spray_mask: Image.Image,
|
expected_box_segmenter_spray_mask: Image.Image,
|
||||||
expected_box_segmenter_spray_cropped_mask: Image.Image,
|
expected_box_segmenter_spray_cropped_mask: Image.Image,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
):
|
):
|
||||||
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
|
segmenter = BoxSegmenter(weights=box_segmenter_weights_path, device=test_device)
|
||||||
|
|
||||||
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
|
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
|
||||||
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
|
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
|
||||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
|
from refiners.conversion.model_converter import ConversionStage, ModelConverter
|
||||||
from refiners.fluxion.utils import manual_seed
|
from refiners.fluxion.utils import manual_seed
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -20,17 +19,13 @@ PROMPTS = [
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_encoder_with_new_concepts(
|
def our_encoder_with_new_concepts(
|
||||||
test_weights_path: Path,
|
sd15_text_encoder_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
cat_embedding_textual_inversion: torch.Tensor,
|
cat_embedding_textual_inversion: torch.Tensor,
|
||||||
gta5_artwork_embedding_textual_inversion: torch.Tensor,
|
gta5_artwork_embedding_textual_inversion: torch.Tensor,
|
||||||
) -> CLIPTextEncoderL:
|
) -> CLIPTextEncoderL:
|
||||||
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
encoder = CLIPTextEncoderL(device=test_device)
|
encoder = CLIPTextEncoderL(device=test_device)
|
||||||
tensors = load_from_safetensors(weights)
|
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
encoder.load_state_dict(tensors)
|
encoder.load_state_dict(tensors)
|
||||||
concept_extender = ConceptExtender(encoder)
|
concept_extender = ConceptExtender(encoder)
|
||||||
concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion)
|
concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion)
|
||||||
|
@ -41,24 +36,21 @@ def our_encoder_with_new_concepts(
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_sd15_with_new_concepts(
|
def ref_sd15_with_new_concepts(
|
||||||
runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device
|
sd15_diffusers_runwayml_path: str,
|
||||||
|
test_textual_inversion_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
use_local_weights: bool,
|
||||||
) -> StableDiffusionPipeline:
|
) -> StableDiffusionPipeline:
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path).to(test_device) # type: ignore
|
pipe = StableDiffusionPipeline.from_pretrained( # type: ignore
|
||||||
|
sd15_diffusers_runwayml_path,
|
||||||
|
local_files_only=use_local_weights,
|
||||||
|
).to(test_device) # type: ignore
|
||||||
assert isinstance(pipe, StableDiffusionPipeline)
|
assert isinstance(pipe, StableDiffusionPipeline)
|
||||||
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
|
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
|
||||||
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
|
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def runwayml_weights_path(test_weights_path: Path):
|
|
||||||
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
|
|
||||||
if not r.is_dir():
|
|
||||||
warn(f"could not find RunwayML weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer:
|
def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer:
|
||||||
return ref_sd15_with_new_concepts.tokenizer # type: ignore
|
return ref_sd15_with_new_concepts.tokenizer # type: ignore
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -11,39 +10,28 @@ from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_encoder(
|
def our_encoder(
|
||||||
test_weights_path: Path,
|
clip_image_encoder_huge_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
) -> CLIPImageEncoderH:
|
) -> CLIPImageEncoderH:
|
||||||
weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
tensors = load_from_safetensors(weights)
|
tensors = load_from_safetensors(clip_image_encoder_huge_weights_path)
|
||||||
encoder.load_state_dict(tensors)
|
encoder.load_state_dict(tensors)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def stabilityai_unclip_weights_path(test_weights_path: Path):
|
|
||||||
r = test_weights_path / "stabilityai" / "stable-diffusion-2-1-unclip"
|
|
||||||
if not r.is_dir():
|
|
||||||
warn(f"could not find Stability AI weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_encoder(
|
def ref_encoder(
|
||||||
stabilityai_unclip_weights_path: Path,
|
unclip21_transformers_stabilityai_path: str,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
|
use_local_weights: bool,
|
||||||
) -> CLIPVisionModelWithProjection:
|
) -> CLIPVisionModelWithProjection:
|
||||||
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
|
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
|
||||||
stabilityai_unclip_weights_path,
|
unclip21_transformers_stabilityai_path,
|
||||||
|
local_files_only=use_local_weights,
|
||||||
subfolder="image_encoder",
|
subfolder="image_encoder",
|
||||||
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -31,42 +30,39 @@ PROMPTS = [
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_encoder(
|
def our_encoder(
|
||||||
test_weights_path: Path,
|
sd15_text_encoder_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
test_dtype_fp32_fp16: torch.dtype,
|
test_dtype_fp32_fp16: torch.dtype,
|
||||||
) -> CLIPTextEncoderL:
|
) -> CLIPTextEncoderL:
|
||||||
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
tensors = load_from_safetensors(weights)
|
|
||||||
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
|
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
|
||||||
|
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
|
|
||||||
encoder.load_state_dict(tensors)
|
encoder.load_state_dict(tensors)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def runwayml_weights_path(test_weights_path: Path):
|
def ref_tokenizer(
|
||||||
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
|
sd15_diffusers_runwayml_path: str,
|
||||||
if not r.is_dir():
|
use_local_weights: bool,
|
||||||
warn(f"could not find RunwayML weights at {r}, skipping")
|
) -> transformers.CLIPTokenizer:
|
||||||
pytest.skip(allow_module_level=True)
|
return transformers.CLIPTokenizer.from_pretrained( # type: ignore
|
||||||
return r
|
sd15_diffusers_runwayml_path,
|
||||||
|
local_files_only=use_local_weights,
|
||||||
|
subfolder="tokenizer",
|
||||||
@pytest.fixture(scope="module")
|
)
|
||||||
def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
|
|
||||||
return transformers.CLIPTokenizer.from_pretrained(runwayml_weights_path, subfolder="tokenizer") # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_encoder(
|
def ref_encoder(
|
||||||
runwayml_weights_path: Path,
|
sd15_diffusers_runwayml_path: str,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
test_dtype_fp32_fp16: torch.dtype,
|
test_dtype_fp32_fp16: torch.dtype,
|
||||||
|
use_local_weights: bool,
|
||||||
) -> transformers.CLIPTextModel:
|
) -> transformers.CLIPTextModel:
|
||||||
return transformers.CLIPTextModel.from_pretrained( # type: ignore
|
return transformers.CLIPTextModel.from_pretrained( # type: ignore
|
||||||
runwayml_weights_path,
|
sd15_diffusers_runwayml_path,
|
||||||
|
local_files_only=use_local_weights,
|
||||||
subfolder="text_encoder",
|
subfolder="text_encoder",
|
||||||
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore
|
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ from warnings import warn
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download # type: ignore
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
|
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||||
from refiners.foundationals.dinov2.dinov2 import (
|
from refiners.foundationals.dinov2.dinov2 import (
|
||||||
|
@ -18,7 +19,7 @@ from refiners.foundationals.dinov2.dinov2 import (
|
||||||
)
|
)
|
||||||
from refiners.foundationals.dinov2.vit import ViT
|
from refiners.foundationals.dinov2.vit import ViT
|
||||||
|
|
||||||
FLAVORS_MAP = {
|
FLAVORS_MAP_REFINERS = {
|
||||||
"dinov2_vits14": DINOv2_small,
|
"dinov2_vits14": DINOv2_small,
|
||||||
"dinov2_vits14_reg": DINOv2_small_reg,
|
"dinov2_vits14_reg": DINOv2_small_reg,
|
||||||
"dinov2_vitb14": DINOv2_base,
|
"dinov2_vitb14": DINOv2_base,
|
||||||
|
@ -28,6 +29,27 @@ FLAVORS_MAP = {
|
||||||
"dinov2_vitg14": DINOv2_giant,
|
"dinov2_vitg14": DINOv2_giant,
|
||||||
"dinov2_vitg14_reg": DINOv2_giant_reg,
|
"dinov2_vitg14_reg": DINOv2_giant_reg,
|
||||||
}
|
}
|
||||||
|
FLAVORS_MAP_HUB = {
|
||||||
|
"dinov2_vits14": "refiners/dinov2.small.patch_14",
|
||||||
|
"dinov2_vits14_reg": "refiners/dinov2.small.patch_14.reg_4",
|
||||||
|
"dinov2_vitb14": "refiners/dinov2.base.patch_14",
|
||||||
|
"dinov2_vitb14_reg": "refiners/dinov2.base.patch_14.reg_4",
|
||||||
|
"dinov2_vitl14": "refiners/dinov2.large.patch_14",
|
||||||
|
"dinov2_vitl14_reg": "refiners/dinov2.large.patch_14.reg_4",
|
||||||
|
"dinov2_vitg14": "refiners/dinov2.giant.patch_14",
|
||||||
|
"dinov2_vitg14_reg": "refiners/dinov2.giant.patch_14.reg_4",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=["float16", "bfloat16"])
|
||||||
|
def dtype(request: pytest.FixtureRequest) -> torch.dtype:
|
||||||
|
match request.param:
|
||||||
|
case "float16":
|
||||||
|
return torch.float16
|
||||||
|
case "bfloat16":
|
||||||
|
return torch.bfloat16
|
||||||
|
case _ as dtype:
|
||||||
|
raise ValueError(f"unsupported dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=[224, 518])
|
@pytest.fixture(scope="module", params=[224, 518])
|
||||||
|
@ -35,7 +57,7 @@ def resolution(request: pytest.FixtureRequest) -> int:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
|
@pytest.fixture(scope="module", params=FLAVORS_MAP_REFINERS.keys())
|
||||||
def flavor(request: pytest.FixtureRequest) -> str:
|
def flavor(request: pytest.FixtureRequest) -> str:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
@ -53,7 +75,14 @@ def dinov2_repo_path(test_repos_path: Path) -> Path:
|
||||||
def ref_model(
|
def ref_model(
|
||||||
flavor: str,
|
flavor: str,
|
||||||
dinov2_repo_path: Path,
|
dinov2_repo_path: Path,
|
||||||
test_weights_path: Path,
|
dinov2_small_unconverted_weights_path: Path,
|
||||||
|
dinov2_small_reg4_unconverted_weights_path: Path,
|
||||||
|
dinov2_base_unconverted_weights_path: Path,
|
||||||
|
dinov2_base_reg4_unconverted_weights_path: Path,
|
||||||
|
dinov2_large_unconverted_weights_path: Path,
|
||||||
|
dinov2_large_reg4_unconverted_weights_path: Path,
|
||||||
|
dinov2_giant_unconverted_weights_path: Path,
|
||||||
|
dinov2_giant_reg4_unconverted_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
kwargs: dict[str, Any] = {}
|
kwargs: dict[str, Any] = {}
|
||||||
|
@ -69,34 +98,51 @@ def ref_model(
|
||||||
)
|
)
|
||||||
model = model.to(device=test_device)
|
model = model.to(device=test_device)
|
||||||
|
|
||||||
flavor = flavor.replace("_reg", "_reg4")
|
weight_map = {
|
||||||
weights = test_weights_path / f"{flavor}_pretrain.pth"
|
"dinov2_vits14": dinov2_small_unconverted_weights_path,
|
||||||
if not weights.is_file():
|
"dinov2_vits14_reg": dinov2_small_reg4_unconverted_weights_path,
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
"dinov2_vitb14": dinov2_base_unconverted_weights_path,
|
||||||
pytest.skip(allow_module_level=True)
|
"dinov2_vitb14_reg": dinov2_base_reg4_unconverted_weights_path,
|
||||||
model.load_state_dict(load_tensors(weights, device=test_device))
|
"dinov2_vitl14": dinov2_large_unconverted_weights_path,
|
||||||
|
"dinov2_vitl14_reg": dinov2_large_reg4_unconverted_weights_path,
|
||||||
|
"dinov2_vitg14": dinov2_giant_unconverted_weights_path,
|
||||||
|
"dinov2_vitg14_reg": dinov2_giant_reg4_unconverted_weights_path,
|
||||||
|
}
|
||||||
|
weights_path = weight_map[flavor]
|
||||||
|
|
||||||
|
model.load_state_dict(load_tensors(weights_path, device=test_device))
|
||||||
assert isinstance(model, torch.nn.Module)
|
assert isinstance(model, torch.nn.Module)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_model(
|
def our_model(
|
||||||
test_weights_path: Path,
|
|
||||||
flavor: str,
|
flavor: str,
|
||||||
|
dinov2_small_weights_path: Path,
|
||||||
|
dinov2_small_reg4_weights_path: Path,
|
||||||
|
dinov2_base_weights_path: Path,
|
||||||
|
dinov2_base_reg4_weights_path: Path,
|
||||||
|
dinov2_large_weights_path: Path,
|
||||||
|
dinov2_large_reg4_weights_path: Path,
|
||||||
|
dinov2_giant_weights_path: Path,
|
||||||
|
dinov2_giant_reg4_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> ViT:
|
) -> ViT:
|
||||||
model = FLAVORS_MAP[flavor](device=test_device)
|
weight_map = {
|
||||||
|
"dinov2_vits14": dinov2_small_weights_path,
|
||||||
|
"dinov2_vits14_reg": dinov2_small_reg4_weights_path,
|
||||||
|
"dinov2_vitb14": dinov2_base_weights_path,
|
||||||
|
"dinov2_vitb14_reg": dinov2_base_reg4_weights_path,
|
||||||
|
"dinov2_vitl14": dinov2_large_weights_path,
|
||||||
|
"dinov2_vitl14_reg": dinov2_large_reg4_weights_path,
|
||||||
|
"dinov2_vitg14": dinov2_giant_weights_path,
|
||||||
|
"dinov2_vitg14_reg": dinov2_giant_reg4_weights_path,
|
||||||
|
}
|
||||||
|
weights_path = weight_map[flavor]
|
||||||
|
|
||||||
flavor = flavor.replace("_reg", "_reg4")
|
model = FLAVORS_MAP_REFINERS[flavor](device=test_device)
|
||||||
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
|
tensors = load_from_safetensors(weights_path)
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
tensors = load_from_safetensors(weights)
|
|
||||||
model.load_state_dict(tensors)
|
model.load_state_dict(tensors)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
139
tests/foundationals/latent_diffusion/conftest.py
Normal file
139
tests/foundationals/latent_diffusion/conftest.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors
|
||||||
|
from refiners.foundationals.latent_diffusion import (
|
||||||
|
CLIPTextEncoderL,
|
||||||
|
DoubleTextEncoder,
|
||||||
|
SD1Autoencoder,
|
||||||
|
SD1UNet,
|
||||||
|
SDXLAutoencoder,
|
||||||
|
SDXLUNet,
|
||||||
|
StableDiffusion_1,
|
||||||
|
StableDiffusion_XL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sd15_autoencoder(sd15_autoencoder_weights_path: Path) -> SD1Autoencoder:
|
||||||
|
autoencoder = SD1Autoencoder()
|
||||||
|
tensors = load_from_safetensors(sd15_autoencoder_weights_path)
|
||||||
|
autoencoder.load_state_dict(tensors)
|
||||||
|
return autoencoder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sd15_unet(sd15_unet_weights_path: Path) -> SD1UNet:
|
||||||
|
unet = SD1UNet(in_channels=4)
|
||||||
|
tensors = load_from_safetensors(sd15_unet_weights_path)
|
||||||
|
unet.load_state_dict(tensors)
|
||||||
|
return unet
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sd15_text_encoder(sd15_text_encoder_weights_path: Path) -> CLIPTextEncoderL:
|
||||||
|
text_encoder = CLIPTextEncoderL()
|
||||||
|
tensors = load_from_safetensors(sd15_text_encoder_weights_path)
|
||||||
|
text_encoder.load_state_dict(tensors)
|
||||||
|
return text_encoder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sd15(
|
||||||
|
refiners_sd15_autoencoder: SD1Autoencoder,
|
||||||
|
refiners_sd15_unet: SD1UNet,
|
||||||
|
refiners_sd15_text_encoder: CLIPTextEncoderL,
|
||||||
|
) -> StableDiffusion_1:
|
||||||
|
return StableDiffusion_1(
|
||||||
|
lda=refiners_sd15_autoencoder,
|
||||||
|
unet=refiners_sd15_unet,
|
||||||
|
clip_text_encoder=refiners_sd15_text_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sdxl_autoencoder(sdxl_autoencoder_weights_path: Path) -> SDXLAutoencoder:
|
||||||
|
autoencoder = SDXLAutoencoder()
|
||||||
|
tensors = load_from_safetensors(sdxl_autoencoder_weights_path)
|
||||||
|
autoencoder.load_state_dict(tensors)
|
||||||
|
return autoencoder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sdxl_unet(sdxl_unet_weights_path: Path) -> SDXLUNet:
|
||||||
|
unet = SDXLUNet(in_channels=4)
|
||||||
|
tensors = load_from_safetensors(sdxl_unet_weights_path)
|
||||||
|
unet.load_state_dict(tensors)
|
||||||
|
return unet
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sdxl_text_encoder(sdxl_text_encoder_weights_path: Path) -> DoubleTextEncoder:
|
||||||
|
text_encoder = DoubleTextEncoder()
|
||||||
|
tensors = load_from_safetensors(sdxl_text_encoder_weights_path)
|
||||||
|
text_encoder.load_state_dict(tensors)
|
||||||
|
return text_encoder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def refiners_sdxl(
|
||||||
|
refiners_sdxl_autoencoder: SDXLAutoencoder,
|
||||||
|
refiners_sdxl_unet: SDXLUNet,
|
||||||
|
refiners_sd15_text_encoder: DoubleTextEncoder,
|
||||||
|
) -> StableDiffusion_XL:
|
||||||
|
return StableDiffusion_XL(
|
||||||
|
lda=refiners_sdxl_autoencoder,
|
||||||
|
unet=refiners_sdxl_unet,
|
||||||
|
clip_text_encoder=refiners_sd15_text_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
|
||||||
|
def refiners_autoencoder(
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
refiners_sd15_autoencoder: SD1Autoencoder,
|
||||||
|
refiners_sdxl_autoencoder: SDXLAutoencoder,
|
||||||
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
|
) -> SD1Autoencoder | SDXLAutoencoder:
|
||||||
|
model_version = request.param
|
||||||
|
match (model_version, test_dtype_fp32_bf16_fp16):
|
||||||
|
case ("SD1.5", _):
|
||||||
|
return refiners_sd15_autoencoder
|
||||||
|
case ("SDXL", torch.float16):
|
||||||
|
return refiners_sdxl_autoencoder
|
||||||
|
case ("SDXL", _):
|
||||||
|
return refiners_sdxl_autoencoder
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown model version: {model_version}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def diffusers_sd15_pipeline(
|
||||||
|
sd15_diffusers_runwayml_path: str,
|
||||||
|
use_local_weights: bool,
|
||||||
|
) -> StableDiffusionPipeline:
|
||||||
|
return StableDiffusionPipeline.from_pretrained( # type: ignore
|
||||||
|
sd15_diffusers_runwayml_path,
|
||||||
|
local_files_only=use_local_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def diffusers_sdxl_pipeline(
|
||||||
|
sdxl_diffusers_stabilityai_path: str,
|
||||||
|
use_local_weights: bool,
|
||||||
|
) -> StableDiffusionXLPipeline:
|
||||||
|
return StableDiffusionXLPipeline.from_pretrained( # type: ignore
|
||||||
|
sdxl_diffusers_stabilityai_path,
|
||||||
|
local_files_only=use_local_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def diffusers_sdxl_unet(diffusers_sdxl_pipeline: StableDiffusionXLPipeline) -> UNet2DConditionModel:
|
||||||
|
return diffusers_sdxl_pipeline.unet # type: ignore
|
|
@ -1,134 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from tests.utils import ensure_similar_images
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import no_grad
|
|
||||||
from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def ref_path() -> Path:
|
|
||||||
return Path(__file__).parent / "test_auto_encoder_ref"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
|
|
||||||
def lda(
|
|
||||||
request: pytest.FixtureRequest,
|
|
||||||
test_weights_path: Path,
|
|
||||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
|
||||||
test_device: torch.device,
|
|
||||||
) -> LatentDiffusionAutoencoder:
|
|
||||||
model_version = request.param
|
|
||||||
match (model_version, test_dtype_fp32_bf16_fp16):
|
|
||||||
case ("SD1.5", _):
|
|
||||||
weight_path = test_weights_path / "lda.safetensors"
|
|
||||||
if not weight_path.is_file():
|
|
||||||
warn(f"could not find weights at {weight_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
model = SD1Autoencoder().load_from_safetensors(weight_path)
|
|
||||||
case ("SDXL", torch.float16):
|
|
||||||
weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
|
||||||
if not weight_path.is_file():
|
|
||||||
warn(f"could not find weights at {weight_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
model = SDXLAutoencoder().load_from_safetensors(weight_path)
|
|
||||||
case ("SDXL", _):
|
|
||||||
weight_path = test_weights_path / "sdxl-lda.safetensors"
|
|
||||||
if not weight_path.is_file():
|
|
||||||
warn(f"could not find weights at {weight_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
model = SDXLAutoencoder().load_from_safetensors(weight_path)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unknown model version: {model_version}")
|
|
||||||
model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sample_image(ref_path: Path) -> Image.Image:
|
|
||||||
test_image = ref_path / "macaw.png"
|
|
||||||
if not test_image.is_file():
|
|
||||||
warn(f"could not reference image at {test_image}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
img = Image.open(test_image) # type: ignore
|
|
||||||
assert img.size == (512, 512)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
|
||||||
encoded = lda.image_to_latents(sample_image)
|
|
||||||
decoded = lda.latents_to_image(encoded)
|
|
||||||
|
|
||||||
assert decoded.mode == "RGB" # type: ignore
|
|
||||||
|
|
||||||
# Ensure no saturation. The green channel (band = 1) must not max out.
|
|
||||||
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
|
|
||||||
|
|
||||||
ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
|
||||||
encoded = lda.images_to_latents([sample_image, sample_image])
|
|
||||||
images = lda.latents_to_images(encoded)
|
|
||||||
assert isinstance(images, list)
|
|
||||||
assert len(images) == 2
|
|
||||||
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
|
||||||
sample_image = sample_image.resize((2048, 2048)) # type: ignore
|
|
||||||
|
|
||||||
with lda.tiled_inference(sample_image, tile_size=(512, 512)):
|
|
||||||
encoded = lda.tiled_image_to_latents(sample_image)
|
|
||||||
result = lda.tiled_latents_to_image(encoded)
|
|
||||||
|
|
||||||
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
|
||||||
sample_image = sample_image.resize((2048, 2048)) # type: ignore
|
|
||||||
|
|
||||||
with lda.tiled_inference(sample_image, tile_size=(512, 1024)):
|
|
||||||
encoded = lda.tiled_image_to_latents(sample_image)
|
|
||||||
result = lda.tiled_latents_to_image(encoded)
|
|
||||||
|
|
||||||
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
|
||||||
sample_image = sample_image.resize((1024, 1024)) # type: ignore
|
|
||||||
|
|
||||||
with lda.tiled_inference(sample_image, tile_size=(2048, 2048)):
|
|
||||||
encoded = lda.tiled_image_to_latents(sample_image)
|
|
||||||
result = lda.tiled_latents_to_image(encoded)
|
|
||||||
|
|
||||||
ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
|
||||||
sample_image = sample_image.crop((0, 0, 300, 500))
|
|
||||||
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore
|
|
||||||
|
|
||||||
with lda.tiled_inference(sample_image, tile_size=(512, 512)):
|
|
||||||
encoded = lda.tiled_image_to_latents(sample_image)
|
|
||||||
result = lda.tiled_latents_to_image(encoded)
|
|
||||||
|
|
||||||
ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
|
|
||||||
|
|
||||||
|
|
||||||
def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
lda.tiled_image_to_latents(sample_image)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
lda.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=lda.device))
|
|
104
tests/foundationals/latent_diffusion/test_autoencoders.py
Normal file
104
tests/foundationals/latent_diffusion/test_autoencoders.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import no_grad
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def sample_image() -> Image.Image:
|
||||||
|
test_image = Path(__file__).parent / "test_auto_encoder_ref" / "macaw.png"
|
||||||
|
if not test_image.is_file():
|
||||||
|
warn(f"could not reference image at {test_image}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
img = Image.open(test_image) # type: ignore
|
||||||
|
assert img.size == (512, 512)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def autoencoder(
|
||||||
|
refiners_autoencoder: LatentDiffusionAutoencoder,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> LatentDiffusionAutoencoder:
|
||||||
|
return refiners_autoencoder.to(test_device)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_encode_decode_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||||
|
encoded = autoencoder.image_to_latents(sample_image)
|
||||||
|
decoded = autoencoder.latents_to_image(encoded)
|
||||||
|
|
||||||
|
assert decoded.mode == "RGB" # type: ignore
|
||||||
|
|
||||||
|
# Ensure no saturation. The green channel (band = 1) must not max out.
|
||||||
|
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
|
||||||
|
|
||||||
|
ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||||
|
encoded = autoencoder.images_to_latents([sample_image, sample_image])
|
||||||
|
images = autoencoder.latents_to_images(encoded)
|
||||||
|
assert isinstance(images, list)
|
||||||
|
assert len(images) == 2
|
||||||
|
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||||
|
sample_image = sample_image.resize((2048, 2048)) # type: ignore
|
||||||
|
|
||||||
|
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
|
||||||
|
encoded = autoencoder.tiled_image_to_latents(sample_image)
|
||||||
|
result = autoencoder.tiled_latents_to_image(encoded)
|
||||||
|
|
||||||
|
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||||
|
sample_image = sample_image.resize((2048, 2048)) # type: ignore
|
||||||
|
|
||||||
|
with autoencoder.tiled_inference(sample_image, tile_size=(512, 1024)):
|
||||||
|
encoded = autoencoder.tiled_image_to_latents(sample_image)
|
||||||
|
result = autoencoder.tiled_latents_to_image(encoded)
|
||||||
|
|
||||||
|
ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||||
|
sample_image = sample_image.resize((1024, 1024)) # type: ignore
|
||||||
|
|
||||||
|
with autoencoder.tiled_inference(sample_image, tile_size=(2048, 2048)):
|
||||||
|
encoded = autoencoder.tiled_image_to_latents(sample_image)
|
||||||
|
result = autoencoder.tiled_latents_to_image(encoded)
|
||||||
|
|
||||||
|
ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||||
|
sample_image = sample_image.crop((0, 0, 300, 500))
|
||||||
|
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore
|
||||||
|
|
||||||
|
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
|
||||||
|
encoded = autoencoder.tiled_image_to_latents(sample_image)
|
||||||
|
result = autoencoder.tiled_latents_to_image(encoded)
|
||||||
|
|
||||||
|
ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
|
||||||
|
|
||||||
|
|
||||||
|
def test_value_error_tile_encode_no_context(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
autoencoder.tiled_image_to_latents(sample_image)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
autoencoder.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=autoencoder.device))
|
|
@ -1,31 +0,0 @@
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import manual_seed, no_grad
|
|
||||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1_Inpainting
|
|
||||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_sample_noise():
|
|
||||||
manual_seed(2)
|
|
||||||
latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))
|
|
||||||
manual_seed(2)
|
|
||||||
latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
|
|
||||||
|
|
||||||
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_sd1_inpainting(test_device: torch.device) -> None:
|
|
||||||
sd = StableDiffusion_1_Inpainting(device=test_device)
|
|
||||||
|
|
||||||
latent_noise = torch.randn(1, 4, 64, 64, device=test_device)
|
|
||||||
target_image = Image.new("RGB", (512, 512))
|
|
||||||
mask = Image.new("L", (512, 512))
|
|
||||||
|
|
||||||
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
|
|
||||||
text_embedding = sd.compute_clip_text_embedding("")
|
|
||||||
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
|
|
||||||
|
|
||||||
assert output.shape == (1, 4, 64, 64)
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
import torch
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import manual_seed, no_grad
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_text_encoder(
|
||||||
|
diffusers_sd15_pipeline: StableDiffusionPipeline,
|
||||||
|
refiners_sd15_text_encoder: CLIPTextEncoderL,
|
||||||
|
) -> None:
|
||||||
|
"""Compare our refiners implementation with the diffusers implementation."""
|
||||||
|
manual_seed(seed=0) # unnecessary, but just in case
|
||||||
|
prompt = "A photo of a pizza."
|
||||||
|
negative_prompt = ""
|
||||||
|
atol = 1e-2 # FIXME: very high tolerance, figure out why
|
||||||
|
|
||||||
|
( # encode text prompts using diffusers pipeline
|
||||||
|
diffusers_embeds, # type: ignore
|
||||||
|
diffusers_negative_embeds, # type: ignore
|
||||||
|
) = diffusers_sd15_pipeline.encode_prompt( # type: ignore
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
device=diffusers_sd15_pipeline.device,
|
||||||
|
)
|
||||||
|
assert isinstance(diffusers_embeds, Tensor)
|
||||||
|
assert isinstance(diffusers_negative_embeds, Tensor)
|
||||||
|
|
||||||
|
# encode text prompts using refiners model
|
||||||
|
refiners_embeds = refiners_sd15_text_encoder(prompt)
|
||||||
|
refiners_negative_embeds = refiners_sd15_text_encoder("")
|
||||||
|
|
||||||
|
# check that the shapes are the same
|
||||||
|
assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 768)
|
||||||
|
assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 768)
|
||||||
|
|
||||||
|
# check that the values are close
|
||||||
|
assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol)
|
||||||
|
assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_text_encoder_batched(refiners_sd15_text_encoder: CLIPTextEncoderL) -> None:
|
||||||
|
"""Check that encoding two prompts works as expected whether batched or not."""
|
||||||
|
manual_seed(seed=0) # unnecessary, but just in case
|
||||||
|
prompt1 = "A photo of a pizza."
|
||||||
|
prompt2 = "A giant duck."
|
||||||
|
atol = 1e-6
|
||||||
|
|
||||||
|
# encode the two prompts at once
|
||||||
|
embeds_batched = refiners_sd15_text_encoder([prompt1, prompt2])
|
||||||
|
assert embeds_batched.shape == (2, 77, 768)
|
||||||
|
|
||||||
|
# encode the prompts one by one
|
||||||
|
embeds_1 = refiners_sd15_text_encoder(prompt1)
|
||||||
|
embeds_2 = refiners_sd15_text_encoder(prompt2)
|
||||||
|
assert embeds_1.shape == embeds_2.shape == (1, 77, 768)
|
||||||
|
|
||||||
|
# check that the values are close
|
||||||
|
assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol)
|
||||||
|
assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol)
|
|
@ -1,121 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Protocol, cast
|
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
|
||||||
from refiners.fluxion.utils import manual_seed, no_grad
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusersSDXL(Protocol):
|
|
||||||
unet: fl.Module
|
|
||||||
text_encoder: fl.Module
|
|
||||||
text_encoder_2: fl.Module
|
|
||||||
tokenizer: fl.Module
|
|
||||||
tokenizer_2: fl.Module
|
|
||||||
vae: fl.Module
|
|
||||||
|
|
||||||
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
|
|
||||||
|
|
||||||
def encode_prompt(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
prompt_2: str | None = None,
|
|
||||||
negative_prompt: str | None = None,
|
|
||||||
negative_prompt_2: str | None = None,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
|
|
||||||
if not r.is_dir():
|
|
||||||
warn(message=f"could not find Stability SDXL base weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def double_text_encoder_weights(test_weights_path: Path) -> Path:
|
|
||||||
text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
|
|
||||||
if not text_encoder_weights.is_file():
|
|
||||||
warn(f"could not find weights at {text_encoder_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return text_encoder_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
|
|
||||||
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
|
|
||||||
double_text_encoder = DoubleTextEncoder()
|
|
||||||
double_text_encoder.load_from_safetensors(double_text_encoder_weights)
|
|
||||||
|
|
||||||
return double_text_encoder
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
|
|
||||||
manual_seed(seed=0)
|
|
||||||
prompt = "A photo of a pizza."
|
|
||||||
|
|
||||||
(
|
|
||||||
prompt_embeds,
|
|
||||||
negative_prompt_embeds,
|
|
||||||
pooled_prompt_embeds,
|
|
||||||
negative_pooled_prompt_embeds,
|
|
||||||
) = diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
|
|
||||||
|
|
||||||
double_embedding, pooled_embedding = double_text_encoder(prompt)
|
|
||||||
|
|
||||||
assert double_embedding.shape == torch.Size([1, 77, 2048])
|
|
||||||
assert pooled_embedding.shape == torch.Size([1, 1280])
|
|
||||||
|
|
||||||
embedding_1, embedding_2 = cast(
|
|
||||||
tuple[Tensor, Tensor],
|
|
||||||
prompt_embeds.split(split_size=[768, 1280], dim=-1), # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
rembedding_1, rembedding_2 = cast(
|
|
||||||
tuple[Tensor, Tensor],
|
|
||||||
double_embedding.split(split_size=[768, 1280], dim=-1), # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
|
|
||||||
assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3)
|
|
||||||
assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
|
|
||||||
|
|
||||||
negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
|
|
||||||
|
|
||||||
assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
|
|
||||||
assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None:
|
|
||||||
manual_seed(seed=0)
|
|
||||||
prompt1 = "A photo of a pizza."
|
|
||||||
prompt2 = "A giant duck."
|
|
||||||
|
|
||||||
double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2])
|
|
||||||
|
|
||||||
assert double_embedding_b2.shape == torch.Size([2, 77, 2048])
|
|
||||||
assert pooled_embedding_b2.shape == torch.Size([2, 1280])
|
|
||||||
|
|
||||||
double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1)
|
|
||||||
double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2)
|
|
||||||
|
|
||||||
assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
|
|
||||||
assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
|
|
||||||
|
|
||||||
assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
|
|
||||||
assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
import torch
|
||||||
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import manual_seed, no_grad
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_double_text_encoder(
|
||||||
|
diffusers_sdxl_pipeline: StableDiffusionXLPipeline,
|
||||||
|
refiners_sdxl_text_encoder: DoubleTextEncoder,
|
||||||
|
) -> None:
|
||||||
|
"""Compare our refiners implementation with the diffusers implementation."""
|
||||||
|
manual_seed(seed=0) # unnecessary, but just in case
|
||||||
|
prompt = "A photo of a pizza."
|
||||||
|
negative_prompt = ""
|
||||||
|
atol = 1e-6
|
||||||
|
|
||||||
|
( # encode text prompts using diffusers pipeline
|
||||||
|
diffusers_embeds,
|
||||||
|
diffusers_negative_embeds, # type: ignore
|
||||||
|
diffusers_pooled_embeds, # type: ignore
|
||||||
|
diffusers_negative_pooled_embeds, # type: ignore
|
||||||
|
) = diffusers_sdxl_pipeline.encode_prompt(prompt=prompt, negative_prompt=negative_prompt)
|
||||||
|
assert diffusers_negative_embeds is not None
|
||||||
|
assert isinstance(diffusers_pooled_embeds, Tensor)
|
||||||
|
assert isinstance(diffusers_negative_pooled_embeds, Tensor)
|
||||||
|
|
||||||
|
# encode text prompts using refiners model
|
||||||
|
refiners_embeds, refiners_pooled_embeds = refiners_sdxl_text_encoder(prompt)
|
||||||
|
refiners_negative_embeds, refiners_negative_pooled_embeds = refiners_sdxl_text_encoder("")
|
||||||
|
|
||||||
|
# check that the shapes are the same
|
||||||
|
assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 2048)
|
||||||
|
assert diffusers_pooled_embeds.shape == refiners_pooled_embeds.shape == (1, 1280)
|
||||||
|
assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 2048)
|
||||||
|
assert diffusers_negative_pooled_embeds.shape == refiners_negative_pooled_embeds.shape == (1, 1280)
|
||||||
|
|
||||||
|
# check that the values are close
|
||||||
|
assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol)
|
||||||
|
assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol)
|
||||||
|
assert torch.allclose(input=refiners_negative_pooled_embeds, other=diffusers_negative_pooled_embeds, atol=atol)
|
||||||
|
assert torch.allclose(input=refiners_pooled_embeds, other=diffusers_pooled_embeds, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_double_text_encoder_batched(refiners_sdxl_text_encoder: DoubleTextEncoder) -> None:
|
||||||
|
"""Check that encoding two prompts works as expected whether batched or not."""
|
||||||
|
manual_seed(seed=0) # unnecessary, but just in case
|
||||||
|
prompt1 = "A photo of a pizza."
|
||||||
|
prompt2 = "A giant duck."
|
||||||
|
atol = 1e-6
|
||||||
|
|
||||||
|
# encode the two prompts at once
|
||||||
|
embeds_batched, pooled_embeds_batched = refiners_sdxl_text_encoder([prompt1, prompt2])
|
||||||
|
assert embeds_batched.shape == (2, 77, 2048)
|
||||||
|
assert pooled_embeds_batched.shape == (2, 1280)
|
||||||
|
|
||||||
|
# encode the prompts one by one
|
||||||
|
embeds_1, pooled_embeds_1 = refiners_sdxl_text_encoder(prompt1)
|
||||||
|
embeds_2, pooled_embeds_2 = refiners_sdxl_text_encoder(prompt2)
|
||||||
|
assert embeds_1.shape == embeds_2.shape == (1, 77, 2048)
|
||||||
|
assert pooled_embeds_1.shape == pooled_embeds_2.shape == (1, 1280)
|
||||||
|
|
||||||
|
# check that the values are close
|
||||||
|
assert torch.allclose(input=embeds_1, other=embeds_batched[0].unsqueeze(0), atol=atol)
|
||||||
|
assert torch.allclose(input=pooled_embeds_1, other=pooled_embeds_batched[0].unsqueeze(0), atol=atol)
|
||||||
|
assert torch.allclose(input=embeds_2, other=embeds_batched[1].unsqueeze(0), atol=atol)
|
||||||
|
assert torch.allclose(input=pooled_embeds_2, other=pooled_embeds_batched[1].unsqueeze(0), atol=atol)
|
|
@ -1,36 +1,13 @@
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
|
from refiners.conversion.model_converter import ConversionStage, ModelConverter
|
||||||
from refiners.fluxion.utils import manual_seed, no_grad
|
from refiners.fluxion.utils import manual_seed, no_grad
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
|
|
||||||
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
|
|
||||||
if not r.is_dir():
|
|
||||||
warn(f"could not find Stability SDXL base weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
|
|
||||||
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
|
|
||||||
return diffusers_sdxl.unet
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def refiners_sdxl_unet() -> SDXLUNet:
|
def refiners_sdxl_unet() -> SDXLUNet:
|
||||||
unet = SDXLUNet(in_channels=4)
|
unet = SDXLUNet(in_channels=4)
|
||||||
|
@ -38,7 +15,10 @@ def refiners_sdxl_unet() -> SDXLUNet:
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
|
def test_sdxl_unet(
|
||||||
|
diffusers_sdxl_unet: Any,
|
||||||
|
refiners_sdxl_unet: SDXLUNet,
|
||||||
|
) -> None:
|
||||||
source = diffusers_sdxl_unet
|
source = diffusers_sdxl_unet
|
||||||
target = refiners_sdxl_unet
|
target = refiners_sdxl_unet
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import gc
|
import gc
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
from pytest import fixture, skip
|
from pytest import fixture
|
||||||
|
|
||||||
|
|
||||||
@fixture(autouse=True)
|
@fixture(autouse=True)
|
||||||
|
@ -15,12 +14,3 @@ def ensure_gc():
|
||||||
@fixture(scope="package")
|
@fixture(scope="package")
|
||||||
def ref_path(test_sam_path: Path) -> Path:
|
def ref_path(test_sam_path: Path) -> Path:
|
||||||
return test_sam_path / "test_sam_ref"
|
return test_sam_path / "test_sam_ref"
|
||||||
|
|
||||||
|
|
||||||
@fixture(scope="package")
|
|
||||||
def sam_h_weights(test_weights_path: Path) -> Path:
|
|
||||||
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
|
|
||||||
if not sam_h_weights.is_file():
|
|
||||||
warn(f"could not find weights at {sam_h_weights}, skipping")
|
|
||||||
skip(allow_module_level=True)
|
|
||||||
return sam_h_weights
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -36,37 +35,17 @@ def tennis(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore
|
return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def hq_adapter_weights(test_weights_path: Path) -> Path:
|
|
||||||
"""Path to the HQ adapter weights in Refiners format"""
|
|
||||||
refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
|
|
||||||
if not refiners_hq_adapter_sam_weights.is_file():
|
|
||||||
warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return refiners_hq_adapter_sam_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||||
# HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False.
|
# HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False.
|
||||||
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
|
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
|
||||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
|
sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
|
||||||
return sam_h
|
return sam_h
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
|
def reference_sam_h(sam_h_hq_adapter_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
|
||||||
"""Path to the HQ adapter weights in default format"""
|
sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=sam_h_hq_adapter_unconverted_weights_path))
|
||||||
reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth"
|
|
||||||
if not reference_hq_adapter_sam_weights.is_file():
|
|
||||||
warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return reference_hq_adapter_sam_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
|
|
||||||
sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights))
|
|
||||||
return sam_h.to(device=test_device)
|
return sam_h.to(device=test_device)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,11 +121,11 @@ def test_mask_decoder_tokens_extender() -> None:
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_early_vit_embedding(
|
def test_early_vit_embedding(
|
||||||
sam_h: SegmentAnythingH,
|
sam_h: SegmentAnythingH,
|
||||||
hq_adapter_weights: Path,
|
sam_h_hq_adapter_weights_path: Path,
|
||||||
reference_sam_h: FacebookSAM,
|
reference_sam_h: FacebookSAM,
|
||||||
tennis: Image.Image,
|
tennis: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore
|
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore
|
||||||
|
|
||||||
|
@ -159,8 +138,8 @@ def test_early_vit_embedding(
|
||||||
assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)
|
assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)
|
||||||
|
|
||||||
|
|
||||||
def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
def test_tokens(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM) -> None:
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)
|
mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)
|
||||||
|
|
||||||
|
@ -175,8 +154,10 @@ def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
def test_compress_vit_feat(
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
|
||||||
|
) -> None:
|
||||||
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
|
early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
|
||||||
|
|
||||||
|
@ -189,8 +170,10 @@ def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
def test_embedding_encoder(
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
|
||||||
|
) -> None:
|
||||||
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
|
x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
|
||||||
|
|
||||||
|
@ -203,8 +186,10 @@ def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, re
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
def test_hq_token_mlp(
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, reference_sam_h: FacebookSAM
|
||||||
|
) -> None:
|
||||||
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)
|
x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)
|
||||||
|
|
||||||
|
@ -217,13 +202,13 @@ def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, referen
|
||||||
@pytest.mark.parametrize("hq_mask_only", [True, False])
|
@pytest.mark.parametrize("hq_mask_only", [True, False])
|
||||||
def test_predictor(
|
def test_predictor(
|
||||||
sam_h: SegmentAnythingH,
|
sam_h: SegmentAnythingH,
|
||||||
hq_adapter_weights: Path,
|
sam_h_hq_adapter_weights_path: Path,
|
||||||
hq_mask_only: bool,
|
hq_mask_only: bool,
|
||||||
reference_sam_h_predictor: FacebookSAMPredictorHQ,
|
reference_sam_h_predictor: FacebookSAMPredictorHQ,
|
||||||
tennis: Image.Image,
|
tennis: Image.Image,
|
||||||
one_prompt: SAMPrompt,
|
one_prompt: SAMPrompt,
|
||||||
) -> None:
|
) -> None:
|
||||||
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
adapter.hq_mask_only = hq_mask_only
|
adapter.hq_mask_only = hq_mask_only
|
||||||
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
|
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
|
||||||
|
@ -268,13 +253,13 @@ def test_predictor(
|
||||||
@pytest.mark.parametrize("hq_mask_only", [True, False])
|
@pytest.mark.parametrize("hq_mask_only", [True, False])
|
||||||
def test_predictor_equal(
|
def test_predictor_equal(
|
||||||
sam_h: SegmentAnythingH,
|
sam_h: SegmentAnythingH,
|
||||||
hq_adapter_weights: Path,
|
sam_h_hq_adapter_weights_path: Path,
|
||||||
hq_mask_only: bool,
|
hq_mask_only: bool,
|
||||||
reference_sam_h_predictor: FacebookSAMPredictorHQ,
|
reference_sam_h_predictor: FacebookSAMPredictorHQ,
|
||||||
tennis: Image.Image,
|
tennis: Image.Image,
|
||||||
one_prompt: SAMPrompt,
|
one_prompt: SAMPrompt,
|
||||||
) -> None:
|
) -> None:
|
||||||
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
adapter.hq_mask_only = hq_mask_only
|
adapter.hq_mask_only = hq_mask_only
|
||||||
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
|
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
|
||||||
|
@ -318,8 +303,8 @@ def test_predictor_equal(
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
|
def test_batch_mask_decoder(sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path) -> None:
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(sam_h_hq_adapter_weights_path)).inject()
|
||||||
|
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
|
|
||||||
|
@ -348,8 +333,10 @@ def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -
|
||||||
assert torch.equal(mask_prediction[0], mask_prediction[1])
|
assert torch.equal(mask_prediction[0], mask_prediction[1])
|
||||||
|
|
||||||
|
|
||||||
def test_hq_sam_load_save_weights(sam_h: SegmentAnythingH, hq_adapter_weights: Path, test_device: torch.device) -> None:
|
def test_hq_sam_load_save_weights(
|
||||||
weights = load_from_safetensors(hq_adapter_weights, device=test_device)
|
sam_h: SegmentAnythingH, sam_h_hq_adapter_weights_path: Path, test_device: torch.device
|
||||||
|
) -> None:
|
||||||
|
weights = load_from_safetensors(sam_h_hq_adapter_weights_path, device=test_device)
|
||||||
|
|
||||||
hq_sam_adapter = HQSAMAdapter(sam_h)
|
hq_sam_adapter = HQSAMAdapter(sam_h)
|
||||||
out_weights_init = hq_sam_adapter.weights
|
out_weights_init = hq_sam_adapter.weights
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from math import isclose
|
from math import isclose
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -17,8 +16,8 @@ from tests.foundationals.segment_anything.utils import (
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.conversion.model_converter import ModelConverter
|
||||||
from refiners.fluxion import manual_seed
|
from refiners.fluxion import manual_seed
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
||||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
||||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||||
|
@ -49,20 +48,11 @@ def one_prompt() -> SAMPrompt:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def facebook_sam_h_weights(test_weights_path: Path) -> Path:
|
def facebook_sam_h(sam_h_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
|
||||||
sam_h_weights = test_weights_path / "sam_vit_h_4b8939.pth"
|
|
||||||
if not sam_h_weights.is_file():
|
|
||||||
warn(f"could not find weights at {sam_h_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return sam_h_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
|
|
||||||
from segment_anything import build_sam_vit_h # type: ignore
|
from segment_anything import build_sam_vit_h # type: ignore
|
||||||
|
|
||||||
sam_h = cast(FacebookSAM, build_sam_vit_h())
|
sam_h = cast(FacebookSAM, build_sam_vit_h())
|
||||||
sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
|
sam_h.load_state_dict(state_dict=load_tensors(sam_h_unconverted_weights_path))
|
||||||
return sam_h.to(device=test_device)
|
return sam_h.to(device=test_device)
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,16 +66,16 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||||
sam_h = SegmentAnythingH(device=test_device)
|
sam_h = SegmentAnythingH(device=test_device)
|
||||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
|
sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
|
||||||
return sam_h
|
return sam_h
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
def sam_h_single_output(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||||
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
|
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
|
||||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
|
sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
|
||||||
return sam_h
|
return sam_h
|
||||||
|
|
||||||
|
|
||||||
|
@ -469,7 +459,10 @@ def test_predictor_resized_single_output(
|
||||||
|
|
||||||
|
|
||||||
def test_mask_encoder(
|
def test_mask_encoder(
|
||||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
facebook_sam_h_predictor: FacebookSAMPredictor,
|
||||||
|
sam_h: SegmentAnythingH,
|
||||||
|
truck: Image.Image,
|
||||||
|
one_prompt: SAMPrompt,
|
||||||
) -> None:
|
) -> None:
|
||||||
predictor = facebook_sam_h_predictor
|
predictor = facebook_sam_h_predictor
|
||||||
predictor.set_image(np.array(truck))
|
predictor.set_image(np.array(truck))
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -25,16 +24,11 @@ class CifarDataset(Dataset[torch.Tensor]):
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def dinov2_l(
|
def dinov2_l(
|
||||||
test_weights_path: Path,
|
dinov2_large_weights_path: Path,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> dinov2.DINOv2_large:
|
) -> dinov2.DINOv2_large:
|
||||||
weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors"
|
|
||||||
if not weights.is_file():
|
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
model = dinov2.DINOv2_large(device=test_device)
|
model = dinov2.DINOv2_large(device=test_device)
|
||||||
model.load_from_safetensors(weights)
|
model.load_from_safetensors(dinov2_large_weights_path)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,11 +24,20 @@ def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int
|
||||||
|
|
||||||
class T5TextEmbedder(nn.Module):
|
class T5TextEmbedder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pretrained_path: Path = Path("tests/weights/QQGYLab/T5XLFP16"), max_length: int | None = None
|
self,
|
||||||
|
pretrained_path: Path | str,
|
||||||
|
max_length: int | None = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__() # type: ignore[reportUnknownMemberType]
|
super().__init__() # type: ignore[reportUnknownMemberType]
|
||||||
self.model: nn.Module = T5EncoderModel.from_pretrained(pretrained_path, local_files_only=True) # type: ignore
|
self.model: nn.Module = T5EncoderModel.from_pretrained( # type: ignore
|
||||||
self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained(pretrained_path, local_files_only=True) # type: ignore
|
pretrained_path,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
self.tokenizer: transformers.T5Tokenizer = T5Tokenizer.from_pretrained( # type: ignore
|
||||||
|
pretrained_path,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
Loading…
Reference in a new issue