deprecate _img_open in tests

This commit is contained in:
Laurent 2024-10-15 13:51:19 +00:00 committed by Laureηt
parent 0620473109
commit 82fd2a63f7
7 changed files with 80 additions and 108 deletions

View file

@ -50,10 +50,6 @@ from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAd
from ..weight_paths import get_path from ..weight_paths import get_path
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -68,132 +64,132 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def cutecat_init(ref_path: Path) -> Image.Image: def cutecat_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cutecat_init.png").convert("RGB") return Image.open(ref_path / "cutecat_init.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def kitchen_dog(ref_path: Path) -> Image.Image: def kitchen_dog(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "kitchen_dog.png").convert("RGB") return Image.open(ref_path / "kitchen_dog.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def kitchen_dog_mask(ref_path: Path) -> Image.Image: def kitchen_dog_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB") return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def woman_image(ref_path: Path) -> Image.Image: def woman_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "woman.png").convert("RGB") return Image.open(ref_path / "woman.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def statue_image(ref_path: Path) -> Image.Image: def statue_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "statue.png").convert("RGB") return Image.open(ref_path / "statue.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image: def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB") return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image: def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB") return Image.open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image: def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB") return Image.open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image: def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB") return Image.open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image: def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB") return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image: def expected_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB") return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image: def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB") return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image: def expected_image_std_init_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB") return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ella_adapter(ref_path: Path) -> Image.Image: def expected_image_ella_adapter(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ella_adapter.png").convert("RGB") return Image.open(ref_path / "expected_image_ella_adapter.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image: def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB") return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image: def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB") return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB") return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB") return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB") return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image: def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB") return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image: def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB") return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB") return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image: def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB") return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image: def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB") return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image: def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB") return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image: def expected_style_aligned(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB") return Image.open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
@ -207,8 +203,8 @@ def controlnet_data(
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]: ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param cn_name: str = request.param
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_fn = { weights_fn = {
"depth": controlnet_depth_weights_path, "depth": controlnet_depth_weights_path,
@ -229,8 +225,8 @@ def controlnet_data_scale_decay(
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]: ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param cn_name: str = request.param
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
weights_fn = { weights_fn = {
"canny": controlnet_canny_weights_path, "canny": controlnet_canny_weights_path,
@ -245,8 +241,8 @@ def controlnet_data_tile(
ref_path: Path, ref_path: Path,
controlnet_tiles_weights_path: Path, controlnet_tiles_weights_path: Path,
) -> tuple[Image.Image, Image.Image, Path]: ) -> tuple[Image.Image, Image.Image, Path]:
condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore condition_image = Image.open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore
expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
return condition_image, expected_image, controlnet_tiles_weights_path return condition_image, expected_image, controlnet_tiles_weights_path
@ -256,8 +252,8 @@ def controlnet_data_canny(
controlnet_canny_weights_path: Path, controlnet_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]: ) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "canny" cn_name = "canny"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
return cn_name, condition_image, expected_image, controlnet_canny_weights_path return cn_name, condition_image, expected_image, controlnet_canny_weights_path
@ -267,8 +263,8 @@ def controlnet_data_depth(
controlnet_depth_weights_path: Path, controlnet_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]: ) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "depth" cn_name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
return cn_name, condition_image, expected_image, controlnet_depth_weights_path return cn_name, condition_image, expected_image, controlnet_depth_weights_path
@ -336,12 +332,12 @@ def controllora_sdxl_config(
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]: ) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
name: str = request.param[0] name: str = request.param[0]
configs: dict[str, ControlLoraConfig] = request.param[1] configs: dict[str, ControlLoraConfig] = request.param[1]
expected_image = _img_open(ref_path / name).convert("RGB") expected_image = Image.open(ref_path / name).convert("RGB")
loaded_configs = { loaded_configs = {
config_name: ControlLoraResolvedConfig( config_name: ControlLoraResolvedConfig(
scale=config.scale, scale=config.scale,
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"), condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
weights_path=get_path(config.weights, use_local_weights), weights_path=get_path(config.weights, use_local_weights),
) )
for config_name, config in configs.items() for config_name, config in configs.items()
@ -356,8 +352,8 @@ def t2i_adapter_data_depth(
t2i_depth_weights_path: Path, t2i_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]: ) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth" name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB") condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
return name, condition_image, expected_image, t2i_depth_weights_path return name, condition_image, expected_image, t2i_depth_weights_path
@ -367,8 +363,8 @@ def t2i_adapter_xl_data_canny(
t2i_sdxl_canny_weights_path: Path, t2i_sdxl_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]: ) -> tuple[str, Image.Image, Image.Image, Path]:
name = "canny" name = "canny"
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB") condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
return name, condition_image, expected_image, t2i_sdxl_canny_weights_path return name, condition_image, expected_image, t2i_sdxl_canny_weights_path
@ -377,7 +373,7 @@ def lora_data_pokemon(
ref_path: Path, ref_path: Path,
lora_pokemon_weights_path: Path, lora_pokemon_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]: ) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB") expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
tensors = load_tensors(lora_pokemon_weights_path) tensors = load_tensors(lora_pokemon_weights_path)
return expected_image, tensors return expected_image, tensors
@ -387,7 +383,7 @@ def lora_data_dpo(
ref_path: Path, ref_path: Path,
lora_dpo_weights_path: Path, lora_dpo_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]: ) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB") expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
tensors = load_from_safetensors(lora_dpo_weights_path) tensors = load_from_safetensors(lora_dpo_weights_path)
return expected_image, tensors return expected_image, tensors
@ -411,62 +407,62 @@ def lora_sliders(
@pytest.fixture @pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image: def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-scene.png").convert("RGB") return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
@pytest.fixture @pytest.fixture
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image: def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-mask.png").convert("RGB") return Image.open(ref_path / "inpainting-mask.png").convert("RGB")
@pytest.fixture @pytest.fixture
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image: def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-target.png").convert("RGB") return Image.open(ref_path / "inpainting-target.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image: def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB") return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_refonly(ref_path: Path) -> Image.Image: def expected_image_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_refonly.png").convert("RGB") return Image.open(ref_path / "expected_refonly.png").convert("RGB")
@pytest.fixture @pytest.fixture
def condition_image_refonly(ref_path: Path) -> Image.Image: def condition_image_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB") return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image: def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB") return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_multi_diffusion(ref_path: Path) -> Image.Image: def expected_multi_diffusion(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB") return Image.open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image: def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB") return Image.open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image: def expected_restart(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB") return Image.open(ref_path / "expected_restart.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_freeu(ref_path: Path) -> Image.Image: def expected_freeu(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB") return Image.open(ref_path / "expected_freeu.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image: def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB") return Image.open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
@ -476,10 +472,10 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
image_prompt = assets / "dragon_quest_slime.jpg" image_prompt = assets / "dragon_quest_slime.jpg"
condition_image = assets / "dropy_canny.png" condition_image = assets / "dropy_canny.png"
return ( return (
_img_open(dropy).convert(mode="RGB"), Image.open(dropy).convert(mode="RGB"),
_img_open(image_prompt).convert(mode="RGB"), Image.open(image_prompt).convert(mode="RGB"),
_img_open(condition_image).convert(mode="RGB"), Image.open(condition_image).convert(mode="RGB"),
_img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"), Image.open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
) )
@ -2639,7 +2635,7 @@ def test_multi_upscaler_small(
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def expected_ic_light(ref_path: Path) -> Image.Image: def expected_ic_light(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_ic_light.png").convert("RGB") return Image.open(ref_path / "expected_ic_light.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View file

@ -13,10 +13,6 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -51,27 +47,27 @@ def sdxl(
@pytest.fixture @pytest.fixture
def image_prompt_german_castle(ref_path: Path) -> Image.Image: def image_prompt_german_castle(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "german-castle.jpg").convert("RGB") return Image.open(ref_path / "german-castle.jpg").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB") return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB") return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB") return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB") return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
@no_grad() @no_grad()

View file

@ -15,10 +15,6 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -33,17 +29,17 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture @pytest.fixture
def expected_lcm_base(ref_path: Path) -> Image.Image: def expected_lcm_base(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_lcm_base.png").convert("RGB") return Image.open(ref_path / "expected_lcm_base.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image: def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB") return Image.open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image: def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB") return Image.open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
@no_grad() @no_grad()

View file

@ -14,10 +14,6 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -32,17 +28,17 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture @pytest.fixture
def expected_lightning_base_4step(ref_path: Path) -> Image.Image: def expected_lightning_base_4step(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_lightning_base_4step.png").convert("RGB") return Image.open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lightning_base_1step(ref_path: Path) -> Image.Image: def expected_lightning_base_1step(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_lightning_base_1step.png").convert("RGB") return Image.open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lightning_lora_4step(ref_path: Path) -> Image.Image: def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_lightning_lora_4step.png").convert("RGB") return Image.open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
@no_grad() @no_grad()

View file

@ -10,10 +10,6 @@ from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_t
from refiners.foundationals.swin.mvanet import MVANet from refiners.foundationals.swin.mvanet import MVANet
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path: def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_mvanet_ref" return test_e2e_path / "test_mvanet_ref"
@ -21,12 +17,12 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_cactus(ref_path: Path) -> Image.Image: def ref_cactus(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cactus.png").convert("RGB") return Image.open(ref_path / "cactus.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_cactus_mask(ref_path: Path) -> Image.Image: def expected_cactus_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cactus_mask.png") return Image.open(ref_path / "expected_cactus_mask.png")
@pytest.fixture @pytest.fixture

View file

@ -9,10 +9,6 @@ from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def diffusion_ref_path(test_e2e_path: Path) -> Path: def diffusion_ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_diffusion_ref" return test_e2e_path / "test_diffusion_ref"
@ -20,12 +16,12 @@ def diffusion_ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def cutecat_init(diffusion_ref_path: Path) -> Image.Image: def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
return _img_open(diffusion_ref_path / "cutecat_init.png").convert("RGB") return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image: def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB") return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
@pytest.fixture @pytest.fixture

View file

@ -8,10 +8,6 @@ from tests.utils import ensure_similar_images
from refiners.solutions import BoxSegmenter from refiners.solutions import BoxSegmenter
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path: def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_solutions_ref" return test_e2e_path / "test_solutions_ref"
@ -19,22 +15,22 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_shelves(ref_path: Path) -> Image.Image: def ref_shelves(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "shelves.jpg").convert("RGB") return Image.open(ref_path / "shelves.jpg").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image: def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_plant_mask.png") return Image.open(ref_path / "expected_box_segmenter_plant_mask.png")
@pytest.fixture @pytest.fixture
def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image: def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_mask.png") return Image.open(ref_path / "expected_box_segmenter_spray_mask.png")
@pytest.fixture @pytest.fixture
def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image: def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png") return Image.open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
def test_box_segmenter( def test_box_segmenter(