From fd5a15c7e0664907817bba906ab206ea9a30920c Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 2 Apr 2024 17:30:57 +0200 Subject: [PATCH] update pyright and fix Pillow 10.3 typing issues --- pyproject.toml | 2 +- requirements.lock | 63 +++++----- src/refiners/fluxion/utils.py | 1 + .../latent_diffusion/image_prompt.py | 3 +- .../foundationals/latent_diffusion/model.py | 3 +- .../foundationals/segment_anything/model.py | 3 +- src/refiners/training_utils/trainer.py | 5 +- tests/e2e/test_diffusion.py | 108 +++++++++--------- tests/e2e/test_doc_examples.py | 14 ++- tests/e2e/test_lcm.py | 10 +- tests/e2e/test_lightning.py | 10 +- tests/e2e/test_preprocessors.py | 8 +- tests/fluxion/test_utils.py | 8 +- .../latent_diffusion/test_auto_encoder.py | 4 +- .../segment_anything/test_hq_sam.py | 6 +- .../segment_anything/test_sam.py | 7 +- 16 files changed, 144 insertions(+), 111 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 041b5de..8ea9aac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ build-backend = "hatchling.build" [tool.rye] managed = true dev-dependencies = [ - "pyright==1.1.349", + "pyright==1.1.356", "docformatter>=1.7.5", "pytest>=8.0.0", "coverage>=7.4.1", diff --git a/requirements.lock b/requirements.lock index 93fdce7..aeb101b 100644 --- a/requirements.lock +++ b/requirements.lock @@ -23,9 +23,9 @@ attrs==23.2.0 # via aiohttp babel==2.14.0 # via mkdocs-material -bitsandbytes==0.42.0 +bitsandbytes==0.43.0 # via refiners -black==24.2.0 +black==24.3.0 # via refiners certifi==2024.2.2 # via requests @@ -42,14 +42,14 @@ colorama==0.4.6 # via mkdocs-material datasets==2.18.0 # via refiners -diffusers==0.26.3 +diffusers==0.27.2 # via refiners dill==0.3.8 # via datasets # via multiprocess docker-pycreds==0.4.0 # via wandb -filelock==3.13.1 +filelock==3.13.3 # via datasets # via diffusers # via huggingface-hub @@ -67,23 +67,24 @@ ghp-import==2.1.0 # via mkdocs gitdb==4.0.11 # via gitpython -gitpython==3.1.42 +gitpython==3.1.43 # via wandb -griffe==0.41.3 +griffe==0.42.1 # via mkdocstrings-python -huggingface-hub==0.21.3 +huggingface-hub==0.22.2 # via datasets # via diffusers + # via timm # via tokenizers # via transformers idna==3.6 # via requests # via yarl -importlib-metadata==7.0.1 +importlib-metadata==7.1.0 # via diffusers invisible-watermark==0.2.0 # via refiners -jaxtyping==0.2.25 +jaxtyping==0.2.28 # via refiners jinja2==3.1.3 # via mkdocs @@ -92,7 +93,7 @@ jinja2==3.1.3 # via torch loguru==0.7.2 # via refiners -markdown==3.5.2 +markdown==3.6 # via mkdocs # via mkdocs-autorefs # via mkdocs-material @@ -114,7 +115,7 @@ mkdocs-autorefs==1.0.1 # via mkdocstrings mkdocs-literate-nav==0.6.1 # via refiners -mkdocs-material==9.5.12 +mkdocs-material==9.5.17 # via refiners mkdocs-material-extensions==1.3.1 # via mkdocs-material @@ -135,6 +136,7 @@ mypy-extensions==1.0.0 networkx==3.2.1 # via torch numpy==1.26.4 + # via bitsandbytes # via datasets # via diffusers # via invisible-watermark @@ -143,7 +145,6 @@ numpy==1.26.4 # via pandas # via pyarrow # via pywavelets - # via scipy # via torchvision # via transformers nvidia-cublas-cu12==12.1.3.1 @@ -169,14 +170,14 @@ nvidia-cusparse-cu12==12.1.0.106 # via torch nvidia-nccl-cu12==2.19.3 # via torch -nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvjitlink-cu12==12.4.99 # via nvidia-cusolver-cu12 # via nvidia-cusparse-cu12 nvidia-nvtx-cu12==12.1.105 # via torch opencv-python==4.9.0.80 # via invisible-watermark -packaging==23.2 +packaging==24.0 # via black # via datasets # via huggingface-hub @@ -190,7 +191,7 @@ pandas==2.2.1 pathspec==0.12.1 # via black # via mkdocs -pillow==10.2.0 +pillow==10.3.0 # via diffusers # via invisible-watermark # via refiners @@ -207,11 +208,11 @@ protobuf==4.25.3 # via wandb psutil==5.9.8 # via wandb -pyarrow==15.0.0 +pyarrow==15.0.2 # via datasets pyarrow-hotfix==0.6 # via datasets -pydantic==2.6.3 +pydantic==2.6.4 # via refiners pydantic-core==2.16.3 # via pydantic @@ -225,7 +226,7 @@ python-dateutil==2.9.0.post0 # via pandas pytz==2024.1 # via pandas -pywavelets==1.5.0 +pywavelets==1.6.0 # via invisible-watermark pyyaml==6.0.1 # via datasets @@ -233,6 +234,7 @@ pyyaml==6.0.1 # via mkdocs # via pymdown-extensions # via pyyaml-env-tag + # via timm # via transformers # via wandb pyyaml-env-tag==0.1 @@ -252,17 +254,17 @@ requests==2.31.0 safetensors==0.4.2 # via diffusers # via refiners + # via timm # via transformers -scipy==1.12.0 - # via bitsandbytes segment-anything-hq==0.3 + # via refiners segment-anything-py==1.0 # via refiners -sentry-sdk==1.40.6 +sentry-sdk==1.44.0 # via wandb setproctitle==1.3.3 # via wandb -setuptools==69.1.1 +setuptools==69.2.0 # via wandb six==1.16.0 # via docker-pycreds @@ -272,26 +274,32 @@ smmap==5.0.1 sympy==1.12 # via torch timm==0.9.16 + # via refiners tokenizers==0.15.2 # via transformers tomli==2.0.1 # via black # via refiners -torch==2.2.1 +torch==2.2.2 + # via bitsandbytes # via invisible-watermark # via refiners + # via segment-anything-hq # via segment-anything-py + # via timm # via torchvision -torchvision==0.17.1 +torchvision==0.17.2 # via piq # via refiners + # via segment-anything-hq # via segment-anything-py + # via timm tqdm==4.66.2 # via datasets # via huggingface-hub # via refiners # via transformers -transformers==4.38.2 +transformers==4.39.3 # via refiners triton==2.2.0 # via torch @@ -300,7 +308,6 @@ typeguard==2.13.3 typing-extensions==4.10.0 # via black # via huggingface-hub - # via jaxtyping # via pydantic # via pydantic-core # via torch @@ -309,7 +316,7 @@ tzdata==2024.1 urllib3==2.2.1 # via requests # via sentry-sdk -wandb==0.16.3 +wandb==0.16.5 # via refiners watchdog==4.0.0 # via mkdocs @@ -317,5 +324,5 @@ xxhash==3.4.1 # via datasets yarl==1.9.4 # via aiohttp -zipp==3.17.0 +zipp==3.18.1 # via importlib-metadata diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index c37df2a..78ad45f 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -145,6 +145,7 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp """ image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) + assert isinstance(image.mode, str) # type: ignore match image.mode: case "L": image_tensor = image_tensor.unsqueeze(0) diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 14d5bde..72b61f1 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -537,8 +537,9 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): mean: The mean to use for normalization. std: The standard deviation to use for normalization. """ + resized = image.resize(size) # type: ignore return normalize( - image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype), + image_to_tensor(resized, device=self.target.device, dtype=self.target.dtype), mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean, std=[0.26862954, 0.26130258, 0.27577711] if std is None else std, ) diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 497b1bd..aea4bcd 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -52,7 +52,8 @@ class LatentDiffusionModel(fl.Module, ABC): if init_image is None: x = noise else: - encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height))) + resized = init_image.resize(size=(width, height)) # type: ignore + encoded_image = self.lda.image_to_latents(resized) x = self.solver.add_noise( x=encoded_image, noise=noise, diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index b8a50d8..48126f6 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -197,8 +197,9 @@ class SegmentAnything(fl.Chain): h, w = target_size padh = self.image_size - h padw = self.image_size - w + resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore image_tensor = torch.tensor( - np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1), + np.array(resized).astype(np.float32).transpose(2, 0, 1), device=self.device, dtype=self.dtype, ).unsqueeze(0) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 50c9065..84c8d57 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from dataclasses import dataclass from functools import cached_property, wraps from typing import Any, Callable, Generic, Literal, TypeVar, cast @@ -300,7 +300,8 @@ class Trainer(Generic[ConfigType, Batch], ABC): """ ... - @abstractproperty + @property + @abstractmethod def dataset_length(self) -> int: """ Returns the length of the dataset. diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index c4831f8..4431c95 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -34,6 +34,10 @@ from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAd from tests.utils import ensure_similar_images +def _img_open(path: Path) -> Image.Image: + return Image.open(path) # type: ignore + + @pytest.fixture(autouse=True) def ensure_gc(): # Avoid GPU OOMs @@ -48,112 +52,112 @@ def ref_path(test_e2e_path: Path) -> Path: @pytest.fixture(scope="module") def cutecat_init(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "cutecat_init.png").convert("RGB") + return _img_open(ref_path / "cutecat_init.png").convert("RGB") @pytest.fixture(scope="module") def kitchen_dog(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "kitchen_dog.png").convert("RGB") + return _img_open(ref_path / "kitchen_dog.png").convert("RGB") @pytest.fixture(scope="module") def kitchen_dog_mask(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB") + return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB") @pytest.fixture(scope="module") def woman_image(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "woman.png").convert("RGB") + return _img_open(ref_path / "woman.png").convert("RGB") @pytest.fixture(scope="module") def statue_image(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "statue.png").convert("RGB") + return _img_open(ref_path / "statue.png").convert("RGB") @pytest.fixture def expected_image_std_random_init(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_std_random_init.png").convert("RGB") + return _img_open(ref_path / "expected_std_random_init.png").convert("RGB") @pytest.fixture def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB") + return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB") @pytest.fixture def expected_karras_random_init(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB") + return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB") @pytest.fixture def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB") + return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB") @pytest.fixture def expected_image_std_init_image(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_std_init_image.png").convert("RGB") + return _img_open(ref_path / "expected_std_init_image.png").convert("RGB") @pytest.fixture def expected_image_std_inpainting(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB") + return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB") @pytest.fixture def expected_image_controlnet_stack(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB") + return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB") @pytest.fixture def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB") + return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB") @pytest.fixture def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB") + return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB") @pytest.fixture def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB") + return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB") @pytest.fixture def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB") + return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB") @pytest.fixture def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB") + return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB") @pytest.fixture def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB") + return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB") @pytest.fixture def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB") + return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB") @pytest.fixture def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB") + return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB") @pytest.fixture def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB") + return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB") @pytest.fixture def expected_style_aligned(ref_path: Path) -> Image.Image: - return Image.open(fp=ref_path / "expected_style_aligned.png").convert(mode="RGB") + return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB") @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) @@ -161,8 +165,8 @@ def controlnet_data( ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]: cn_name: str = request.param - condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") - expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") + condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") + expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") weights_fn = { "depth": "lllyasviel_control_v11f1p_sd15_depth", "canny": "lllyasviel_control_v11p_sd15_canny", @@ -178,8 +182,8 @@ def controlnet_data( @pytest.fixture(scope="module") def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: cn_name = "canny" - condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") - expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") + condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") + expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors" return cn_name, condition_image, expected_image, weights_path @@ -187,8 +191,8 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, @pytest.fixture(scope="module") def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: cn_name = "depth" - condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") - expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") + condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") + expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors" return cn_name, condition_image, expected_image, weights_path @@ -257,12 +261,12 @@ def controllora_sdxl_config( ) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]: name: str = request.param[0] configs: dict[str, ControlLoraConfig] = request.param[1] - expected_image = Image.open(ref_path / name).convert("RGB") + expected_image = _img_open(ref_path / name).convert("RGB") loaded_configs = { config_name: ControlLoraResolvedConfig( scale=config.scale, - condition_image=Image.open(ref_path / config.condition_path).convert("RGB"), + condition_image=_img_open(ref_path / config.condition_path).convert("RGB"), weights_path=test_weights_path / "control-loras" / config.weights_path, ) for config_name, config in configs.items() @@ -274,8 +278,8 @@ def controllora_sdxl_config( @pytest.fixture(scope="module") def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: name = "depth" - condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB") - expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB") + condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB") + expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB") weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors" return name, condition_image, expected_image, weights_path @@ -283,8 +287,8 @@ def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str @pytest.fixture(scope="module") def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: name = "canny" - condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB") - expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB") + condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB") + expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB") weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors" if not weights_path.is_file(): @@ -296,7 +300,7 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[ @pytest.fixture(scope="module") def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: - expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB") + expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB") weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin" if not weights_path.is_file(): @@ -309,7 +313,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im @pytest.fixture(scope="module") def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: - expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB") + expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB") weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors" if not weights_path.is_file(): @@ -341,57 +345,57 @@ def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Ten @pytest.fixture def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "inpainting-scene.png").convert("RGB") + return _img_open(ref_path / "inpainting-scene.png").convert("RGB") @pytest.fixture def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "inpainting-mask.png").convert("RGB") + return _img_open(ref_path / "inpainting-mask.png").convert("RGB") @pytest.fixture def target_image_inpainting_refonly(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "inpainting-target.png").convert("RGB") + return _img_open(ref_path / "inpainting-target.png").convert("RGB") @pytest.fixture def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB") + return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB") @pytest.fixture def expected_image_refonly(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_refonly.png").convert("RGB") + return _img_open(ref_path / "expected_refonly.png").convert("RGB") @pytest.fixture def condition_image_refonly(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB") + return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB") @pytest.fixture def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB") + return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB") @pytest.fixture def expected_multi_diffusion(ref_path: Path) -> Image.Image: - return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB") + return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB") @pytest.fixture def expected_restart(ref_path: Path) -> Image.Image: - return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB") + return _img_open(ref_path / "expected_restart.png").convert(mode="RGB") @pytest.fixture def expected_freeu(ref_path: Path) -> Image.Image: - return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB") + return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB") @pytest.fixture def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image: - return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB") + return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB") @pytest.fixture @@ -401,10 +405,10 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image. image_prompt = assets / "dragon_quest_slime.jpg" condition_image = assets / "dropy_canny.png" return ( - Image.open(fp=dropy).convert(mode="RGB"), - Image.open(fp=image_prompt).convert(mode="RGB"), - Image.open(fp=condition_image).convert(mode="RGB"), - Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"), + _img_open(dropy).convert(mode="RGB"), + _img_open(image_prompt).convert(mode="RGB"), + _img_open(condition_image).convert(mode="RGB"), + _img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"), ) @@ -2314,7 +2318,7 @@ def test_style_aligned( # tile all images horizontally merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024)) for i in range(len(predicted_images)): - merged_image.paste(predicted_images[i], (i * 1024, 0)) + merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore # compare against reference image ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99) diff --git a/tests/e2e/test_doc_examples.py b/tests/e2e/test_doc_examples.py index c2b51ce..4991848 100644 --- a/tests/e2e/test_doc_examples.py +++ b/tests/e2e/test_doc_examples.py @@ -13,6 +13,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St from tests.utils import ensure_similar_images +def _img_open(path: Path) -> Image.Image: + return Image.open(path) # type: ignore + + @pytest.fixture(autouse=True) def ensure_gc(): # Avoid GPU OOMs @@ -110,27 +114,27 @@ def sdxl( @pytest.fixture def image_prompt_german_castle(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "german-castle.jpg").convert("RGB") + return _img_open(ref_path / "german-castle.jpg").convert("RGB") @pytest.fixture def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB") + return _img_open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB") @pytest.fixture def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB") + return _img_open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB") @pytest.fixture def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB") + return _img_open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB") @pytest.fixture def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB") + return _img_open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB") @no_grad() diff --git a/tests/e2e/test_lcm.py b/tests/e2e/test_lcm.py index 8d52250..af9b2a9 100644 --- a/tests/e2e/test_lcm.py +++ b/tests/e2e/test_lcm.py @@ -15,6 +15,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St from tests.utils import ensure_similar_images +def _img_open(path: Path) -> Image.Image: + return Image.open(path) # type: ignore + + @pytest.fixture(autouse=True) def ensure_gc(): # Avoid GPU OOMs @@ -74,17 +78,17 @@ def ref_path(test_e2e_path: Path) -> Path: @pytest.fixture def expected_lcm_base(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_lcm_base.png").convert("RGB") + return _img_open(ref_path / "expected_lcm_base.png").convert("RGB") @pytest.fixture def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB") + return _img_open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB") @pytest.fixture def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB") + return _img_open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB") @no_grad() diff --git a/tests/e2e/test_lightning.py b/tests/e2e/test_lightning.py index 4f5f6f6..80ea900 100644 --- a/tests/e2e/test_lightning.py +++ b/tests/e2e/test_lightning.py @@ -14,6 +14,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St from tests.utils import ensure_similar_images +def _img_open(path: Path) -> Image.Image: + return Image.open(path) # type: ignore + + @pytest.fixture(autouse=True) def ensure_gc(): # Avoid GPU OOMs @@ -82,17 +86,17 @@ def ref_path(test_e2e_path: Path) -> Path: @pytest.fixture def expected_lightning_base_4step(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_lightning_base_4step.png").convert("RGB") + return _img_open(ref_path / "expected_lightning_base_4step.png").convert("RGB") @pytest.fixture def expected_lightning_base_1step(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_lightning_base_1step.png").convert("RGB") + return _img_open(ref_path / "expected_lightning_base_1step.png").convert("RGB") @pytest.fixture def expected_lightning_lora_4step(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "expected_lightning_lora_4step.png").convert("RGB") + return _img_open(ref_path / "expected_lightning_lora_4step.png").convert("RGB") @no_grad() diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py index 4492638..068b4fe 100644 --- a/tests/e2e/test_preprocessors.py +++ b/tests/e2e/test_preprocessors.py @@ -10,6 +10,10 @@ from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings from tests.utils import ensure_similar_images +def _img_open(path: Path) -> Image.Image: + return Image.open(path) # type: ignore + + @pytest.fixture(scope="module") def diffusion_ref_path(test_e2e_path: Path) -> Path: return test_e2e_path / "test_diffusion_ref" @@ -17,12 +21,12 @@ def diffusion_ref_path(test_e2e_path: Path) -> Path: @pytest.fixture(scope="module") def cutecat_init(diffusion_ref_path: Path) -> Image.Image: - return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB") + return _img_open(diffusion_ref_path / "cutecat_init.png").convert("RGB") @pytest.fixture def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image: - return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB") + return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB") @pytest.fixture(scope="module") diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index c3f6663..3084f7d 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -70,10 +70,10 @@ def test_image_to_tensor() -> None: def test_tensor_to_image() -> None: - assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" - assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" - assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" - assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB" + assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" # type: ignore + assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" # type: ignore + assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" # type: ignore + assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB" # type: ignore def test_summarize_tensor() -> None: diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py index 871135b..d140a24 100644 --- a/tests/foundationals/latent_diffusion/test_auto_encoder.py +++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py @@ -33,7 +33,7 @@ def sample_image(ref_path: Path) -> Image.Image: if not test_image.is_file(): warn(f"could not reference image at {test_image}, skipping") pytest.skip(allow_module_level=True) - img = Image.open(test_image) + img = Image.open(test_image) # type: ignore assert img.size == (512, 512) return img @@ -43,7 +43,7 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: encoded = encoder.image_to_latents(sample_image) decoded = encoder.latents_to_image(encoded) - assert decoded.mode == "RGB" + assert decoded.mode == "RGB" # type: ignore # Ensure no saturation. The green channel (band = 1) must not max out. assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore diff --git a/tests/foundationals/segment_anything/test_hq_sam.py b/tests/foundationals/segment_anything/test_hq_sam.py index 228fd55..c412dc5 100644 --- a/tests/foundationals/segment_anything/test_hq_sam.py +++ b/tests/foundationals/segment_anything/test_hq_sam.py @@ -33,7 +33,7 @@ def one_prompt() -> SAMPrompt: @pytest.fixture(scope="module") def tennis(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "tennis.png").convert("RGB") + return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore @pytest.fixture(scope="module") @@ -148,7 +148,7 @@ def test_early_vit_embedding( ) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() - image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) + image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore _ = sam_h.image_encoder(image_tensor.to(sam_h.device)) early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"] @@ -284,7 +284,7 @@ def test_predictor_equal( # and to use image_embedding as input size = (1024, 1024) - resized_tennis = tennis.resize(size) + resized_tennis = tennis.resize(size) # type: ignore # Reference reference_sam_h_predictor.set_image(np.array(resized_tennis)) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index df37e30..47a5df3 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -91,7 +91,7 @@ def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> Segme @pytest.fixture(scope="module") def truck(ref_path: Path) -> Image.Image: - return Image.open(ref_path / "truck.jpg").convert("RGB") + return Image.open(ref_path / "truck.jpg").convert("RGB") # type: ignore @no_grad() @@ -140,7 +140,8 @@ def test_multimask_output_error() -> None: @no_grad() def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None: - image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device) + resized = truck.resize(size=(1024, 1024)) # type: ignore + image_tensor = image_to_tensor(image=resized, device=facebook_sam_h.device) y_1 = facebook_sam_h.image_encoder(image_tensor) y_2 = sam_h.image_encoder(image_tensor) @@ -451,7 +452,7 @@ def test_predictor_resized_single_output( predictor = facebook_sam_h_predictor size = (1024, 1024) - resized_truck = truck.resize(size) + resized_truck = truck.resize(size) # type: ignore predictor.set_image(np.array(resized_truck)) _, _, facebook_low_res_masks = predictor.predict( # type: ignore