update pyright and fix Pillow 10.3 typing issues

This commit is contained in:
Pierre Chapuis 2024-04-02 17:30:57 +02:00
parent 328fcb8ed1
commit fd5a15c7e0
16 changed files with 144 additions and 111 deletions

View file

@ -73,7 +73,7 @@ build-backend = "hatchling.build"
[tool.rye] [tool.rye]
managed = true managed = true
dev-dependencies = [ dev-dependencies = [
"pyright==1.1.349", "pyright==1.1.356",
"docformatter>=1.7.5", "docformatter>=1.7.5",
"pytest>=8.0.0", "pytest>=8.0.0",
"coverage>=7.4.1", "coverage>=7.4.1",

View file

@ -23,9 +23,9 @@ attrs==23.2.0
# via aiohttp # via aiohttp
babel==2.14.0 babel==2.14.0
# via mkdocs-material # via mkdocs-material
bitsandbytes==0.42.0 bitsandbytes==0.43.0
# via refiners # via refiners
black==24.2.0 black==24.3.0
# via refiners # via refiners
certifi==2024.2.2 certifi==2024.2.2
# via requests # via requests
@ -42,14 +42,14 @@ colorama==0.4.6
# via mkdocs-material # via mkdocs-material
datasets==2.18.0 datasets==2.18.0
# via refiners # via refiners
diffusers==0.26.3 diffusers==0.27.2
# via refiners # via refiners
dill==0.3.8 dill==0.3.8
# via datasets # via datasets
# via multiprocess # via multiprocess
docker-pycreds==0.4.0 docker-pycreds==0.4.0
# via wandb # via wandb
filelock==3.13.1 filelock==3.13.3
# via datasets # via datasets
# via diffusers # via diffusers
# via huggingface-hub # via huggingface-hub
@ -67,23 +67,24 @@ ghp-import==2.1.0
# via mkdocs # via mkdocs
gitdb==4.0.11 gitdb==4.0.11
# via gitpython # via gitpython
gitpython==3.1.42 gitpython==3.1.43
# via wandb # via wandb
griffe==0.41.3 griffe==0.42.1
# via mkdocstrings-python # via mkdocstrings-python
huggingface-hub==0.21.3 huggingface-hub==0.22.2
# via datasets # via datasets
# via diffusers # via diffusers
# via timm
# via tokenizers # via tokenizers
# via transformers # via transformers
idna==3.6 idna==3.6
# via requests # via requests
# via yarl # via yarl
importlib-metadata==7.0.1 importlib-metadata==7.1.0
# via diffusers # via diffusers
invisible-watermark==0.2.0 invisible-watermark==0.2.0
# via refiners # via refiners
jaxtyping==0.2.25 jaxtyping==0.2.28
# via refiners # via refiners
jinja2==3.1.3 jinja2==3.1.3
# via mkdocs # via mkdocs
@ -92,7 +93,7 @@ jinja2==3.1.3
# via torch # via torch
loguru==0.7.2 loguru==0.7.2
# via refiners # via refiners
markdown==3.5.2 markdown==3.6
# via mkdocs # via mkdocs
# via mkdocs-autorefs # via mkdocs-autorefs
# via mkdocs-material # via mkdocs-material
@ -114,7 +115,7 @@ mkdocs-autorefs==1.0.1
# via mkdocstrings # via mkdocstrings
mkdocs-literate-nav==0.6.1 mkdocs-literate-nav==0.6.1
# via refiners # via refiners
mkdocs-material==9.5.12 mkdocs-material==9.5.17
# via refiners # via refiners
mkdocs-material-extensions==1.3.1 mkdocs-material-extensions==1.3.1
# via mkdocs-material # via mkdocs-material
@ -135,6 +136,7 @@ mypy-extensions==1.0.0
networkx==3.2.1 networkx==3.2.1
# via torch # via torch
numpy==1.26.4 numpy==1.26.4
# via bitsandbytes
# via datasets # via datasets
# via diffusers # via diffusers
# via invisible-watermark # via invisible-watermark
@ -143,7 +145,6 @@ numpy==1.26.4
# via pandas # via pandas
# via pyarrow # via pyarrow
# via pywavelets # via pywavelets
# via scipy
# via torchvision # via torchvision
# via transformers # via transformers
nvidia-cublas-cu12==12.1.3.1 nvidia-cublas-cu12==12.1.3.1
@ -169,14 +170,14 @@ nvidia-cusparse-cu12==12.1.0.106
# via torch # via torch
nvidia-nccl-cu12==2.19.3 nvidia-nccl-cu12==2.19.3
# via torch # via torch
nvidia-nvjitlink-cu12==12.3.101 nvidia-nvjitlink-cu12==12.4.99
# via nvidia-cusolver-cu12 # via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12 # via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105 nvidia-nvtx-cu12==12.1.105
# via torch # via torch
opencv-python==4.9.0.80 opencv-python==4.9.0.80
# via invisible-watermark # via invisible-watermark
packaging==23.2 packaging==24.0
# via black # via black
# via datasets # via datasets
# via huggingface-hub # via huggingface-hub
@ -190,7 +191,7 @@ pandas==2.2.1
pathspec==0.12.1 pathspec==0.12.1
# via black # via black
# via mkdocs # via mkdocs
pillow==10.2.0 pillow==10.3.0
# via diffusers # via diffusers
# via invisible-watermark # via invisible-watermark
# via refiners # via refiners
@ -207,11 +208,11 @@ protobuf==4.25.3
# via wandb # via wandb
psutil==5.9.8 psutil==5.9.8
# via wandb # via wandb
pyarrow==15.0.0 pyarrow==15.0.2
# via datasets # via datasets
pyarrow-hotfix==0.6 pyarrow-hotfix==0.6
# via datasets # via datasets
pydantic==2.6.3 pydantic==2.6.4
# via refiners # via refiners
pydantic-core==2.16.3 pydantic-core==2.16.3
# via pydantic # via pydantic
@ -225,7 +226,7 @@ python-dateutil==2.9.0.post0
# via pandas # via pandas
pytz==2024.1 pytz==2024.1
# via pandas # via pandas
pywavelets==1.5.0 pywavelets==1.6.0
# via invisible-watermark # via invisible-watermark
pyyaml==6.0.1 pyyaml==6.0.1
# via datasets # via datasets
@ -233,6 +234,7 @@ pyyaml==6.0.1
# via mkdocs # via mkdocs
# via pymdown-extensions # via pymdown-extensions
# via pyyaml-env-tag # via pyyaml-env-tag
# via timm
# via transformers # via transformers
# via wandb # via wandb
pyyaml-env-tag==0.1 pyyaml-env-tag==0.1
@ -252,17 +254,17 @@ requests==2.31.0
safetensors==0.4.2 safetensors==0.4.2
# via diffusers # via diffusers
# via refiners # via refiners
# via timm
# via transformers # via transformers
scipy==1.12.0
# via bitsandbytes
segment-anything-hq==0.3 segment-anything-hq==0.3
# via refiners
segment-anything-py==1.0 segment-anything-py==1.0
# via refiners # via refiners
sentry-sdk==1.40.6 sentry-sdk==1.44.0
# via wandb # via wandb
setproctitle==1.3.3 setproctitle==1.3.3
# via wandb # via wandb
setuptools==69.1.1 setuptools==69.2.0
# via wandb # via wandb
six==1.16.0 six==1.16.0
# via docker-pycreds # via docker-pycreds
@ -272,26 +274,32 @@ smmap==5.0.1
sympy==1.12 sympy==1.12
# via torch # via torch
timm==0.9.16 timm==0.9.16
# via refiners
tokenizers==0.15.2 tokenizers==0.15.2
# via transformers # via transformers
tomli==2.0.1 tomli==2.0.1
# via black # via black
# via refiners # via refiners
torch==2.2.1 torch==2.2.2
# via bitsandbytes
# via invisible-watermark # via invisible-watermark
# via refiners # via refiners
# via segment-anything-hq
# via segment-anything-py # via segment-anything-py
# via timm
# via torchvision # via torchvision
torchvision==0.17.1 torchvision==0.17.2
# via piq # via piq
# via refiners # via refiners
# via segment-anything-hq
# via segment-anything-py # via segment-anything-py
# via timm
tqdm==4.66.2 tqdm==4.66.2
# via datasets # via datasets
# via huggingface-hub # via huggingface-hub
# via refiners # via refiners
# via transformers # via transformers
transformers==4.38.2 transformers==4.39.3
# via refiners # via refiners
triton==2.2.0 triton==2.2.0
# via torch # via torch
@ -300,7 +308,6 @@ typeguard==2.13.3
typing-extensions==4.10.0 typing-extensions==4.10.0
# via black # via black
# via huggingface-hub # via huggingface-hub
# via jaxtyping
# via pydantic # via pydantic
# via pydantic-core # via pydantic-core
# via torch # via torch
@ -309,7 +316,7 @@ tzdata==2024.1
urllib3==2.2.1 urllib3==2.2.1
# via requests # via requests
# via sentry-sdk # via sentry-sdk
wandb==0.16.3 wandb==0.16.5
# via refiners # via refiners
watchdog==4.0.0 watchdog==4.0.0
# via mkdocs # via mkdocs
@ -317,5 +324,5 @@ xxhash==3.4.1
# via datasets # via datasets
yarl==1.9.4 yarl==1.9.4
# via aiohttp # via aiohttp
zipp==3.17.0 zipp==3.18.1
# via importlib-metadata # via importlib-metadata

View file

@ -145,6 +145,7 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
""" """
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
assert isinstance(image.mode, str) # type: ignore
match image.mode: match image.mode:
case "L": case "L":
image_tensor = image_tensor.unsqueeze(0) image_tensor = image_tensor.unsqueeze(0)

View file

@ -537,8 +537,9 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
mean: The mean to use for normalization. mean: The mean to use for normalization.
std: The standard deviation to use for normalization. std: The standard deviation to use for normalization.
""" """
resized = image.resize(size) # type: ignore
return normalize( return normalize(
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype), image_to_tensor(resized, device=self.target.device, dtype=self.target.dtype),
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean, mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std, std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
) )

View file

@ -52,7 +52,8 @@ class LatentDiffusionModel(fl.Module, ABC):
if init_image is None: if init_image is None:
x = noise x = noise
else: else:
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height))) resized = init_image.resize(size=(width, height)) # type: ignore
encoded_image = self.lda.image_to_latents(resized)
x = self.solver.add_noise( x = self.solver.add_noise(
x=encoded_image, x=encoded_image,
noise=noise, noise=noise,

View file

@ -197,8 +197,9 @@ class SegmentAnything(fl.Chain):
h, w = target_size h, w = target_size
padh = self.image_size - h padh = self.image_size - h
padw = self.image_size - w padw = self.image_size - w
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
image_tensor = torch.tensor( image_tensor = torch.tensor(
np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1), np.array(resized).astype(np.float32).transpose(2, 0, 1),
device=self.device, device=self.device,
dtype=self.dtype, dtype=self.dtype,
).unsqueeze(0) ).unsqueeze(0)

View file

@ -1,4 +1,4 @@
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property, wraps from functools import cached_property, wraps
from typing import Any, Callable, Generic, Literal, TypeVar, cast from typing import Any, Callable, Generic, Literal, TypeVar, cast
@ -300,7 +300,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
""" """
... ...
@abstractproperty @property
@abstractmethod
def dataset_length(self) -> int: def dataset_length(self) -> int:
""" """
Returns the length of the dataset. Returns the length of the dataset.

View file

@ -34,6 +34,10 @@ from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAd
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -48,112 +52,112 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def cutecat_init(ref_path: Path) -> Image.Image: def cutecat_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "cutecat_init.png").convert("RGB") return _img_open(ref_path / "cutecat_init.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def kitchen_dog(ref_path: Path) -> Image.Image: def kitchen_dog(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "kitchen_dog.png").convert("RGB") return _img_open(ref_path / "kitchen_dog.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def kitchen_dog_mask(ref_path: Path) -> Image.Image: def kitchen_dog_mask(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB") return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def woman_image(ref_path: Path) -> Image.Image: def woman_image(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "woman.png").convert("RGB") return _img_open(ref_path / "woman.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def statue_image(ref_path: Path) -> Image.Image: def statue_image(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "statue.png").convert("RGB") return _img_open(ref_path / "statue.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image: def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB") return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image: def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB") return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image: def expected_karras_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB") return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image: def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB") return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image: def expected_image_std_init_image(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_init_image.png").convert("RGB") return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image: def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB") return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image: def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB") return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB") return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB") return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB") return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image: def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB") return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image: def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB") return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image: def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB") return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image: def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB") return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image: def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB") return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image: def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB") return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image: def expected_style_aligned(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_style_aligned.png").convert(mode="RGB") return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
@ -161,8 +165,8 @@ def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]: ) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param cn_name: str = request.param
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_fn = { weights_fn = {
"depth": "lllyasviel_control_v11f1p_sd15_depth", "depth": "lllyasviel_control_v11f1p_sd15_depth",
"canny": "lllyasviel_control_v11p_sd15_canny", "canny": "lllyasviel_control_v11p_sd15_canny",
@ -178,8 +182,8 @@ def controlnet_data(
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "canny" cn_name = "canny"
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors" weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
return cn_name, condition_image, expected_image, weights_path return cn_name, condition_image, expected_image, weights_path
@ -187,8 +191,8 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str,
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "depth" cn_name = "depth"
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors" weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
return cn_name, condition_image, expected_image, weights_path return cn_name, condition_image, expected_image, weights_path
@ -257,12 +261,12 @@ def controllora_sdxl_config(
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]: ) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
name: str = request.param[0] name: str = request.param[0]
configs: dict[str, ControlLoraConfig] = request.param[1] configs: dict[str, ControlLoraConfig] = request.param[1]
expected_image = Image.open(ref_path / name).convert("RGB") expected_image = _img_open(ref_path / name).convert("RGB")
loaded_configs = { loaded_configs = {
config_name: ControlLoraResolvedConfig( config_name: ControlLoraResolvedConfig(
scale=config.scale, scale=config.scale,
condition_image=Image.open(ref_path / config.condition_path).convert("RGB"), condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
weights_path=test_weights_path / "control-loras" / config.weights_path, weights_path=test_weights_path / "control-loras" / config.weights_path,
) )
for config_name, config in configs.items() for config_name, config in configs.items()
@ -274,8 +278,8 @@ def controllora_sdxl_config(
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth" name = "depth"
condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB") condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors" weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
return name, condition_image, expected_image, weights_path return name, condition_image, expected_image, weights_path
@ -283,8 +287,8 @@ def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
name = "canny" name = "canny"
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB") condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB") expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors" weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
if not weights_path.is_file(): if not weights_path.is_file():
@ -296,7 +300,7 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB") expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin" weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
if not weights_path.is_file(): if not weights_path.is_file():
@ -309,7 +313,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB") expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors" weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
if not weights_path.is_file(): if not weights_path.is_file():
@ -341,57 +345,57 @@ def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Ten
@pytest.fixture @pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image: def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "inpainting-scene.png").convert("RGB") return _img_open(ref_path / "inpainting-scene.png").convert("RGB")
@pytest.fixture @pytest.fixture
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image: def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "inpainting-mask.png").convert("RGB") return _img_open(ref_path / "inpainting-mask.png").convert("RGB")
@pytest.fixture @pytest.fixture
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image: def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "inpainting-target.png").convert("RGB") return _img_open(ref_path / "inpainting-target.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image: def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB") return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_refonly(ref_path: Path) -> Image.Image: def expected_image_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_refonly.png").convert("RGB") return _img_open(ref_path / "expected_refonly.png").convert("RGB")
@pytest.fixture @pytest.fixture
def condition_image_refonly(ref_path: Path) -> Image.Image: def condition_image_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB") return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image: def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB") return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_multi_diffusion(ref_path: Path) -> Image.Image: def expected_multi_diffusion(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB") return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image: def expected_restart(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB") return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_freeu(ref_path: Path) -> Image.Image: def expected_freeu(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB") return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image: def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB") return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
@ -401,10 +405,10 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
image_prompt = assets / "dragon_quest_slime.jpg" image_prompt = assets / "dragon_quest_slime.jpg"
condition_image = assets / "dropy_canny.png" condition_image = assets / "dropy_canny.png"
return ( return (
Image.open(fp=dropy).convert(mode="RGB"), _img_open(dropy).convert(mode="RGB"),
Image.open(fp=image_prompt).convert(mode="RGB"), _img_open(image_prompt).convert(mode="RGB"),
Image.open(fp=condition_image).convert(mode="RGB"), _img_open(condition_image).convert(mode="RGB"),
Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"), _img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
) )
@ -2314,7 +2318,7 @@ def test_style_aligned(
# tile all images horizontally # tile all images horizontally
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024)) merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
for i in range(len(predicted_images)): for i in range(len(predicted_images)):
merged_image.paste(predicted_images[i], (i * 1024, 0)) merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore
# compare against reference image # compare against reference image
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99) ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)

View file

@ -13,6 +13,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -110,27 +114,27 @@ def sdxl(
@pytest.fixture @pytest.fixture
def image_prompt_german_castle(ref_path: Path) -> Image.Image: def image_prompt_german_castle(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "german-castle.jpg").convert("RGB") return _img_open(ref_path / "german-castle.jpg").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB") return _img_open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB") return _img_open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB") return _img_open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image: def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB") return _img_open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
@no_grad() @no_grad()

View file

@ -15,6 +15,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -74,17 +78,17 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture @pytest.fixture
def expected_lcm_base(ref_path: Path) -> Image.Image: def expected_lcm_base(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_lcm_base.png").convert("RGB") return _img_open(ref_path / "expected_lcm_base.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image: def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB") return _img_open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image: def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB") return _img_open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
@no_grad() @no_grad()

View file

@ -14,6 +14,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def ensure_gc(): def ensure_gc():
# Avoid GPU OOMs # Avoid GPU OOMs
@ -82,17 +86,17 @@ def ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture @pytest.fixture
def expected_lightning_base_4step(ref_path: Path) -> Image.Image: def expected_lightning_base_4step(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_lightning_base_4step.png").convert("RGB") return _img_open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lightning_base_1step(ref_path: Path) -> Image.Image: def expected_lightning_base_1step(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_lightning_base_1step.png").convert("RGB") return _img_open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_lightning_lora_4step(ref_path: Path) -> Image.Image: def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_lightning_lora_4step.png").convert("RGB") return _img_open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
@no_grad() @no_grad()

View file

@ -10,6 +10,10 @@ from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def diffusion_ref_path(test_e2e_path: Path) -> Path: def diffusion_ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_diffusion_ref" return test_e2e_path / "test_diffusion_ref"
@ -17,12 +21,12 @@ def diffusion_ref_path(test_e2e_path: Path) -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def cutecat_init(diffusion_ref_path: Path) -> Image.Image: def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB") return _img_open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image: def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB") return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View file

@ -70,10 +70,10 @@ def test_image_to_tensor() -> None:
def test_tensor_to_image() -> None: def test_tensor_to_image() -> None:
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" # type: ignore
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" # type: ignore
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" # type: ignore
assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB" assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB" # type: ignore
def test_summarize_tensor() -> None: def test_summarize_tensor() -> None:

View file

@ -33,7 +33,7 @@ def sample_image(ref_path: Path) -> Image.Image:
if not test_image.is_file(): if not test_image.is_file():
warn(f"could not reference image at {test_image}, skipping") warn(f"could not reference image at {test_image}, skipping")
pytest.skip(allow_module_level=True) pytest.skip(allow_module_level=True)
img = Image.open(test_image) img = Image.open(test_image) # type: ignore
assert img.size == (512, 512) assert img.size == (512, 512)
return img return img
@ -43,7 +43,7 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image:
encoded = encoder.image_to_latents(sample_image) encoded = encoder.image_to_latents(sample_image)
decoded = encoder.latents_to_image(encoded) decoded = encoder.latents_to_image(encoded)
assert decoded.mode == "RGB" assert decoded.mode == "RGB" # type: ignore
# Ensure no saturation. The green channel (band = 1) must not max out. # Ensure no saturation. The green channel (band = 1) must not max out.
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore

View file

@ -33,7 +33,7 @@ def one_prompt() -> SAMPrompt:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def tennis(ref_path: Path) -> Image.Image: def tennis(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "tennis.png").convert("RGB") return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -148,7 +148,7 @@ def test_early_vit_embedding(
) -> None: ) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore
_ = sam_h.image_encoder(image_tensor.to(sam_h.device)) _ = sam_h.image_encoder(image_tensor.to(sam_h.device))
early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"] early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]
@ -284,7 +284,7 @@ def test_predictor_equal(
# and to use image_embedding as input # and to use image_embedding as input
size = (1024, 1024) size = (1024, 1024)
resized_tennis = tennis.resize(size) resized_tennis = tennis.resize(size) # type: ignore
# Reference # Reference
reference_sam_h_predictor.set_image(np.array(resized_tennis)) reference_sam_h_predictor.set_image(np.array(resized_tennis))

View file

@ -91,7 +91,7 @@ def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> Segme
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def truck(ref_path: Path) -> Image.Image: def truck(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "truck.jpg").convert("RGB") return Image.open(ref_path / "truck.jpg").convert("RGB") # type: ignore
@no_grad() @no_grad()
@ -140,7 +140,8 @@ def test_multimask_output_error() -> None:
@no_grad() @no_grad()
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None: def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device) resized = truck.resize(size=(1024, 1024)) # type: ignore
image_tensor = image_to_tensor(image=resized, device=facebook_sam_h.device)
y_1 = facebook_sam_h.image_encoder(image_tensor) y_1 = facebook_sam_h.image_encoder(image_tensor)
y_2 = sam_h.image_encoder(image_tensor) y_2 = sam_h.image_encoder(image_tensor)
@ -451,7 +452,7 @@ def test_predictor_resized_single_output(
predictor = facebook_sam_h_predictor predictor = facebook_sam_h_predictor
size = (1024, 1024) size = (1024, 1024)
resized_truck = truck.resize(size) resized_truck = truck.resize(size) # type: ignore
predictor.set_image(np.array(resized_truck)) predictor.set_image(np.array(resized_truck))
_, _, facebook_low_res_masks = predictor.predict( # type: ignore _, _, facebook_low_res_masks = predictor.predict( # type: ignore