diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 90e3534..c953c68 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -50,10 +50,6 @@ from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAd
 from ..weight_paths import get_path
 
 
-def _img_open(path: Path) -> Image.Image:
-    return Image.open(path)  # type: ignore
-
-
 @pytest.fixture(autouse=True)
 def ensure_gc():
     # Avoid GPU OOMs
@@ -68,132 +64,132 @@ def ref_path(test_e2e_path: Path) -> Path:
 
 @pytest.fixture(scope="module")
 def cutecat_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "cutecat_init.png").convert("RGB")
+    return Image.open(ref_path / "cutecat_init.png").convert("RGB")
 
 
 @pytest.fixture(scope="module")
 def kitchen_dog(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "kitchen_dog.png").convert("RGB")
+    return Image.open(ref_path / "kitchen_dog.png").convert("RGB")
 
 
 @pytest.fixture(scope="module")
 def kitchen_dog_mask(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB")
+    return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")
 
 
 @pytest.fixture(scope="module")
 def woman_image(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "woman.png").convert("RGB")
+    return Image.open(ref_path / "woman.png").convert("RGB")
 
 
 @pytest.fixture(scope="module")
 def statue_image(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "statue.png").convert("RGB")
+    return Image.open(ref_path / "statue.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_random_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_karras_random_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB")
+    return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_init_image(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_ella_adapter(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_ella_adapter.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_ella_adapter.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
+    return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB")
+    return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
+    return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
+    return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
+    return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
+    return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_style_aligned(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
+    return Image.open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
 
 
 @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
@@ -207,8 +203,8 @@ def controlnet_data(
     request: pytest.FixtureRequest,
 ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
     cn_name: str = request.param
-    condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
-    expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
+    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
 
     weights_fn = {
         "depth": controlnet_depth_weights_path,
@@ -229,8 +225,8 @@ def controlnet_data_scale_decay(
     request: pytest.FixtureRequest,
 ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
     cn_name: str = request.param
-    condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
-    expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
+    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
 
     weights_fn = {
         "canny": controlnet_canny_weights_path,
@@ -245,8 +241,8 @@ def controlnet_data_tile(
     ref_path: Path,
     controlnet_tiles_weights_path: Path,
 ) -> tuple[Image.Image, Image.Image, Path]:
-    condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024))  # type: ignore
-    expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
+    condition_image = Image.open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024))  # type: ignore
+    expected_image = Image.open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
     return condition_image, expected_image, controlnet_tiles_weights_path
@@ -256,8 +252,8 @@ def controlnet_data_canny(
     ref_path: Path,
     controlnet_canny_weights_path: Path,
 ) -> tuple[str, Image.Image, Image.Image, Path]:
     cn_name = "canny"
-    condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
-    expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
+    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
     return cn_name, condition_image, expected_image, controlnet_canny_weights_path
@@ -267,8 +263,8 @@ def controlnet_data_depth(
     ref_path: Path,
     controlnet_depth_weights_path: Path,
 ) -> tuple[str, Image.Image, Image.Image, Path]:
     cn_name = "depth"
-    condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
-    expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
+    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
     return cn_name, condition_image, expected_image, controlnet_depth_weights_path
@@ -336,12 +332,12 @@ def controllora_sdxl_config(
 ) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
     name: str = request.param[0]
     configs: dict[str, ControlLoraConfig] = request.param[1]
-    expected_image = _img_open(ref_path / name).convert("RGB")
+    expected_image = Image.open(ref_path / name).convert("RGB")
 
     loaded_configs = {
         config_name: ControlLoraResolvedConfig(
             scale=config.scale,
-            condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
+            condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
             weights_path=get_path(config.weights, use_local_weights),
         )
         for config_name, config in configs.items()
@@ -356,8 +352,8 @@ def t2i_adapter_data_depth(
     ref_path: Path,
     t2i_depth_weights_path: Path,
 ) -> tuple[str, Image.Image, Image.Image, Path]:
     name = "depth"
-    condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
-    expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
+    condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
     return name, condition_image, expected_image, t2i_depth_weights_path
@@ -367,8 +363,8 @@ def t2i_adapter_xl_data_canny(
     ref_path: Path,
     t2i_sdxl_canny_weights_path: Path,
 ) -> tuple[str, Image.Image, Image.Image, Path]:
     name = "canny"
-    condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
-    expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
+    condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
     return name, condition_image, expected_image, t2i_sdxl_canny_weights_path
@@ -377,7 +373,7 @@ def lora_data_pokemon(
     ref_path: Path,
     lora_pokemon_weights_path: Path,
 ) -> tuple[Image.Image, dict[str, torch.Tensor]]:
-    expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
+    expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
     tensors = load_tensors(lora_pokemon_weights_path)
     return expected_image, tensors
@@ -387,7 +383,7 @@ def lora_data_dpo(
     ref_path: Path,
     lora_dpo_weights_path: Path,
 ) -> tuple[Image.Image, dict[str, torch.Tensor]]:
-    expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
+    expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
     tensors = load_from_safetensors(lora_dpo_weights_path)
     return expected_image, tensors
@@ -411,62 +407,62 @@ def lora_sliders(
 
 @pytest.fixture
 def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "inpainting-scene.png").convert("RGB")
+    return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
 
 
 @pytest.fixture
 def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "inpainting-mask.png").convert("RGB")
+    return Image.open(ref_path / "inpainting-mask.png").convert("RGB")
 
 
 @pytest.fixture
 def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "inpainting-target.png").convert("RGB")
+    return Image.open(ref_path / "inpainting-target.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
+    return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_refonly(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_refonly.png").convert("RGB")
+    return Image.open(ref_path / "expected_refonly.png").convert("RGB")
 
 
 @pytest.fixture
 def condition_image_refonly(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB")
+    return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
+    return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_multi_diffusion(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
+    return Image.open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
 
 
 @pytest.fixture
 def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")
+    return Image.open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")
 
 
 @pytest.fixture
 def expected_restart(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
+    return Image.open(ref_path / "expected_restart.png").convert(mode="RGB")
 
 
 @pytest.fixture
 def expected_freeu(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB")
+    return Image.open(ref_path / "expected_freeu.png").convert(mode="RGB")
 
 
 @pytest.fixture
 def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
+    return Image.open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
 
 
 @pytest.fixture
@@ -476,10 +472,10 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
     image_prompt = assets / "dragon_quest_slime.jpg"
     condition_image = assets / "dropy_canny.png"
     return (
-        _img_open(dropy).convert(mode="RGB"),
-        _img_open(image_prompt).convert(mode="RGB"),
-        _img_open(condition_image).convert(mode="RGB"),
-        _img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
+        Image.open(dropy).convert(mode="RGB"),
+        Image.open(image_prompt).convert(mode="RGB"),
+        Image.open(condition_image).convert(mode="RGB"),
+        Image.open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
     )
@@ -2639,7 +2635,7 @@ def test_multi_upscaler_small(
 
 @pytest.fixture(scope="module")
 def expected_ic_light(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_ic_light.png").convert("RGB")
+    return Image.open(ref_path / "expected_ic_light.png").convert("RGB")
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/e2e/test_doc_examples.py b/tests/e2e/test_doc_examples.py
index 4e6d4ef..10c08f7 100644
--- a/tests/e2e/test_doc_examples.py
+++ b/tests/e2e/test_doc_examples.py
@@ -13,10 +13,6 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 
 
-def _img_open(path: Path) -> Image.Image:
-    return Image.open(path)  # type: ignore
-
-
 @pytest.fixture(autouse=True)
 def ensure_gc():
     # Avoid GPU OOMs
@@ -51,27 +47,27 @@ def sdxl(
 
 @pytest.fixture
 def image_prompt_german_castle(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "german-castle.jpg").convert("RGB")
+    return Image.open(ref_path / "german-castle.jpg").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
+    return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
 
 
 @no_grad()
diff --git a/tests/e2e/test_lcm.py b/tests/e2e/test_lcm.py
index c2e03e3..69863be 100644
--- a/tests/e2e/test_lcm.py
+++ b/tests/e2e/test_lcm.py
@@ -15,10 +15,6 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 
 
-def _img_open(path: Path) -> Image.Image:
-    return Image.open(path)  # type: ignore
-
-
 @pytest.fixture(autouse=True)
 def ensure_gc():
     # Avoid GPU OOMs
@@ -33,17 +29,17 @@ def ref_path(test_e2e_path: Path) -> Path:
 
 @pytest.fixture
 def expected_lcm_base(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_lcm_base.png").convert("RGB")
+    return Image.open(ref_path / "expected_lcm_base.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
+    return Image.open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
+    return Image.open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
 
 
 @no_grad()
diff --git a/tests/e2e/test_lightning.py b/tests/e2e/test_lightning.py
index 0acf358..3547cf4 100644
--- a/tests/e2e/test_lightning.py
+++ b/tests/e2e/test_lightning.py
@@ -14,10 +14,6 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 
 
-def _img_open(path: Path) -> Image.Image:
-    return Image.open(path)  # type: ignore
-
-
 @pytest.fixture(autouse=True)
 def ensure_gc():
     # Avoid GPU OOMs
@@ -32,17 +28,17 @@ def ref_path(test_e2e_path: Path) -> Path:
 
 @pytest.fixture
 def expected_lightning_base_4step(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
+    return Image.open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_lightning_base_1step(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
+    return Image.open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
+    return Image.open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
 
 
 @no_grad()
diff --git a/tests/e2e/test_mvanet.py b/tests/e2e/test_mvanet.py
index 0022f71..2ee5ce1 100644
--- a/tests/e2e/test_mvanet.py
+++ b/tests/e2e/test_mvanet.py
@@ -10,10 +10,6 @@ from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_t
 from refiners.foundationals.swin.mvanet import MVANet
 
 
-def _img_open(path: Path) -> Image.Image:
-    return Image.open(path)  # type: ignore
-
-
 @pytest.fixture(scope="module")
 def ref_path(test_e2e_path: Path) -> Path:
     return test_e2e_path / "test_mvanet_ref"
@@ -21,12 +17,12 @@ def ref_path(test_e2e_path: Path) -> Path:
 
 @pytest.fixture(scope="module")
 def ref_cactus(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "cactus.png").convert("RGB")
+    return Image.open(ref_path / "cactus.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_cactus_mask(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_cactus_mask.png")
+    return Image.open(ref_path / "expected_cactus_mask.png")
 
 
 @pytest.fixture
diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py
index c5c9ae6..f8c7c00 100644
--- a/tests/e2e/test_preprocessors.py
+++ b/tests/e2e/test_preprocessors.py
@@ -9,10 +9,6 @@ from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
 from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
 
 
-def _img_open(path: Path) -> Image.Image:
-    return Image.open(path)  # type: ignore
-
-
 @pytest.fixture(scope="module")
 def diffusion_ref_path(test_e2e_path: Path) -> Path:
     return test_e2e_path / "test_diffusion_ref"
@@ -20,12 +16,12 @@ def diffusion_ref_path(test_e2e_path: Path) -> Path:
 
 @pytest.fixture(scope="module")
 def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
-    return _img_open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
+    return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
 
 
 @pytest.fixture
 def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
-    return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
+    return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
 
 
 @pytest.fixture
diff --git a/tests/e2e/test_solutions.py b/tests/e2e/test_solutions.py
index 5dc47ca..892d5ba 100644
--- a/tests/e2e/test_solutions.py
+++ b/tests/e2e/test_solutions.py
@@ -8,10 +8,6 @@ from tests.utils import ensure_similar_images
 from refiners.solutions import BoxSegmenter
 
 
-def _img_open(path: Path) -> Image.Image:
-    return Image.open(path)  # type: ignore
-
-
 @pytest.fixture(scope="module")
 def ref_path(test_e2e_path: Path) -> Path:
     return test_e2e_path / "test_solutions_ref"
@@ -19,22 +15,22 @@ def ref_path(test_e2e_path: Path) -> Path:
 
 @pytest.fixture(scope="module")
 def ref_shelves(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "shelves.jpg").convert("RGB")
+    return Image.open(ref_path / "shelves.jpg").convert("RGB")
 
 
 @pytest.fixture
 def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_box_segmenter_plant_mask.png")
+    return Image.open(ref_path / "expected_box_segmenter_plant_mask.png")
 
 
 @pytest.fixture
 def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_box_segmenter_spray_mask.png")
+    return Image.open(ref_path / "expected_box_segmenter_spray_mask.png")
 
 
 @pytest.fixture
 def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
-    return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
+    return Image.open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
 
 
 def test_box_segmenter(