mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
update pyright and fix Pillow 10.3 typing issues
This commit is contained in:
parent
328fcb8ed1
commit
fd5a15c7e0
|
@ -73,7 +73,7 @@ build-backend = "hatchling.build"
|
|||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = [
|
||||
"pyright==1.1.349",
|
||||
"pyright==1.1.356",
|
||||
"docformatter>=1.7.5",
|
||||
"pytest>=8.0.0",
|
||||
"coverage>=7.4.1",
|
||||
|
|
|
@ -23,9 +23,9 @@ attrs==23.2.0
|
|||
# via aiohttp
|
||||
babel==2.14.0
|
||||
# via mkdocs-material
|
||||
bitsandbytes==0.42.0
|
||||
bitsandbytes==0.43.0
|
||||
# via refiners
|
||||
black==24.2.0
|
||||
black==24.3.0
|
||||
# via refiners
|
||||
certifi==2024.2.2
|
||||
# via requests
|
||||
|
@ -42,14 +42,14 @@ colorama==0.4.6
|
|||
# via mkdocs-material
|
||||
datasets==2.18.0
|
||||
# via refiners
|
||||
diffusers==0.26.3
|
||||
diffusers==0.27.2
|
||||
# via refiners
|
||||
dill==0.3.8
|
||||
# via datasets
|
||||
# via multiprocess
|
||||
docker-pycreds==0.4.0
|
||||
# via wandb
|
||||
filelock==3.13.1
|
||||
filelock==3.13.3
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via huggingface-hub
|
||||
|
@ -67,23 +67,24 @@ ghp-import==2.1.0
|
|||
# via mkdocs
|
||||
gitdb==4.0.11
|
||||
# via gitpython
|
||||
gitpython==3.1.42
|
||||
gitpython==3.1.43
|
||||
# via wandb
|
||||
griffe==0.41.3
|
||||
griffe==0.42.1
|
||||
# via mkdocstrings-python
|
||||
huggingface-hub==0.21.3
|
||||
huggingface-hub==0.22.2
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via timm
|
||||
# via tokenizers
|
||||
# via transformers
|
||||
idna==3.6
|
||||
# via requests
|
||||
# via yarl
|
||||
importlib-metadata==7.0.1
|
||||
importlib-metadata==7.1.0
|
||||
# via diffusers
|
||||
invisible-watermark==0.2.0
|
||||
# via refiners
|
||||
jaxtyping==0.2.25
|
||||
jaxtyping==0.2.28
|
||||
# via refiners
|
||||
jinja2==3.1.3
|
||||
# via mkdocs
|
||||
|
@ -92,7 +93,7 @@ jinja2==3.1.3
|
|||
# via torch
|
||||
loguru==0.7.2
|
||||
# via refiners
|
||||
markdown==3.5.2
|
||||
markdown==3.6
|
||||
# via mkdocs
|
||||
# via mkdocs-autorefs
|
||||
# via mkdocs-material
|
||||
|
@ -114,7 +115,7 @@ mkdocs-autorefs==1.0.1
|
|||
# via mkdocstrings
|
||||
mkdocs-literate-nav==0.6.1
|
||||
# via refiners
|
||||
mkdocs-material==9.5.12
|
||||
mkdocs-material==9.5.17
|
||||
# via refiners
|
||||
mkdocs-material-extensions==1.3.1
|
||||
# via mkdocs-material
|
||||
|
@ -135,6 +136,7 @@ mypy-extensions==1.0.0
|
|||
networkx==3.2.1
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
# via bitsandbytes
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via invisible-watermark
|
||||
|
@ -143,7 +145,6 @@ numpy==1.26.4
|
|||
# via pandas
|
||||
# via pyarrow
|
||||
# via pywavelets
|
||||
# via scipy
|
||||
# via torchvision
|
||||
# via transformers
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
|
@ -169,14 +170,14 @@ nvidia-cusparse-cu12==12.1.0.106
|
|||
# via torch
|
||||
nvidia-nccl-cu12==2.19.3
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.3.101
|
||||
nvidia-nvjitlink-cu12==12.4.99
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
opencv-python==4.9.0.80
|
||||
# via invisible-watermark
|
||||
packaging==23.2
|
||||
packaging==24.0
|
||||
# via black
|
||||
# via datasets
|
||||
# via huggingface-hub
|
||||
|
@ -190,7 +191,7 @@ pandas==2.2.1
|
|||
pathspec==0.12.1
|
||||
# via black
|
||||
# via mkdocs
|
||||
pillow==10.2.0
|
||||
pillow==10.3.0
|
||||
# via diffusers
|
||||
# via invisible-watermark
|
||||
# via refiners
|
||||
|
@ -207,11 +208,11 @@ protobuf==4.25.3
|
|||
# via wandb
|
||||
psutil==5.9.8
|
||||
# via wandb
|
||||
pyarrow==15.0.0
|
||||
pyarrow==15.0.2
|
||||
# via datasets
|
||||
pyarrow-hotfix==0.6
|
||||
# via datasets
|
||||
pydantic==2.6.3
|
||||
pydantic==2.6.4
|
||||
# via refiners
|
||||
pydantic-core==2.16.3
|
||||
# via pydantic
|
||||
|
@ -225,7 +226,7 @@ python-dateutil==2.9.0.post0
|
|||
# via pandas
|
||||
pytz==2024.1
|
||||
# via pandas
|
||||
pywavelets==1.5.0
|
||||
pywavelets==1.6.0
|
||||
# via invisible-watermark
|
||||
pyyaml==6.0.1
|
||||
# via datasets
|
||||
|
@ -233,6 +234,7 @@ pyyaml==6.0.1
|
|||
# via mkdocs
|
||||
# via pymdown-extensions
|
||||
# via pyyaml-env-tag
|
||||
# via timm
|
||||
# via transformers
|
||||
# via wandb
|
||||
pyyaml-env-tag==0.1
|
||||
|
@ -252,17 +254,17 @@ requests==2.31.0
|
|||
safetensors==0.4.2
|
||||
# via diffusers
|
||||
# via refiners
|
||||
# via timm
|
||||
# via transformers
|
||||
scipy==1.12.0
|
||||
# via bitsandbytes
|
||||
segment-anything-hq==0.3
|
||||
# via refiners
|
||||
segment-anything-py==1.0
|
||||
# via refiners
|
||||
sentry-sdk==1.40.6
|
||||
sentry-sdk==1.44.0
|
||||
# via wandb
|
||||
setproctitle==1.3.3
|
||||
# via wandb
|
||||
setuptools==69.1.1
|
||||
setuptools==69.2.0
|
||||
# via wandb
|
||||
six==1.16.0
|
||||
# via docker-pycreds
|
||||
|
@ -272,26 +274,32 @@ smmap==5.0.1
|
|||
sympy==1.12
|
||||
# via torch
|
||||
timm==0.9.16
|
||||
# via refiners
|
||||
tokenizers==0.15.2
|
||||
# via transformers
|
||||
tomli==2.0.1
|
||||
# via black
|
||||
# via refiners
|
||||
torch==2.2.1
|
||||
torch==2.2.2
|
||||
# via bitsandbytes
|
||||
# via invisible-watermark
|
||||
# via refiners
|
||||
# via segment-anything-hq
|
||||
# via segment-anything-py
|
||||
# via timm
|
||||
# via torchvision
|
||||
torchvision==0.17.1
|
||||
torchvision==0.17.2
|
||||
# via piq
|
||||
# via refiners
|
||||
# via segment-anything-hq
|
||||
# via segment-anything-py
|
||||
# via timm
|
||||
tqdm==4.66.2
|
||||
# via datasets
|
||||
# via huggingface-hub
|
||||
# via refiners
|
||||
# via transformers
|
||||
transformers==4.38.2
|
||||
transformers==4.39.3
|
||||
# via refiners
|
||||
triton==2.2.0
|
||||
# via torch
|
||||
|
@ -300,7 +308,6 @@ typeguard==2.13.3
|
|||
typing-extensions==4.10.0
|
||||
# via black
|
||||
# via huggingface-hub
|
||||
# via jaxtyping
|
||||
# via pydantic
|
||||
# via pydantic-core
|
||||
# via torch
|
||||
|
@ -309,7 +316,7 @@ tzdata==2024.1
|
|||
urllib3==2.2.1
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
wandb==0.16.3
|
||||
wandb==0.16.5
|
||||
# via refiners
|
||||
watchdog==4.0.0
|
||||
# via mkdocs
|
||||
|
@ -317,5 +324,5 @@ xxhash==3.4.1
|
|||
# via datasets
|
||||
yarl==1.9.4
|
||||
# via aiohttp
|
||||
zipp==3.17.0
|
||||
zipp==3.18.1
|
||||
# via importlib-metadata
|
||||
|
|
|
@ -145,6 +145,7 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
|
|||
"""
|
||||
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
|
||||
|
||||
assert isinstance(image.mode, str) # type: ignore
|
||||
match image.mode:
|
||||
case "L":
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
|
|
|
@ -537,8 +537,9 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
mean: The mean to use for normalization.
|
||||
std: The standard deviation to use for normalization.
|
||||
"""
|
||||
resized = image.resize(size) # type: ignore
|
||||
return normalize(
|
||||
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
|
||||
image_to_tensor(resized, device=self.target.device, dtype=self.target.dtype),
|
||||
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
||||
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
|
||||
)
|
||||
|
|
|
@ -52,7 +52,8 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
if init_image is None:
|
||||
x = noise
|
||||
else:
|
||||
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
|
||||
resized = init_image.resize(size=(width, height)) # type: ignore
|
||||
encoded_image = self.lda.image_to_latents(resized)
|
||||
x = self.solver.add_noise(
|
||||
x=encoded_image,
|
||||
noise=noise,
|
||||
|
|
|
@ -197,8 +197,9 @@ class SegmentAnything(fl.Chain):
|
|||
h, w = target_size
|
||||
padh = self.image_size - h
|
||||
padw = self.image_size - w
|
||||
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
|
||||
image_tensor = torch.tensor(
|
||||
np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1),
|
||||
np.array(resized).astype(np.float32).transpose(2, 0, 1),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
).unsqueeze(0)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, wraps
|
||||
from typing import Any, Callable, Generic, Literal, TypeVar, cast
|
||||
|
@ -300,7 +300,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
"""
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
@property
|
||||
@abstractmethod
|
||||
def dataset_length(self) -> int:
|
||||
"""
|
||||
Returns the length of the dataset.
|
||||
|
|
|
@ -34,6 +34,10 @@ from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAd
|
|||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
||||
def _img_open(path: Path) -> Image.Image:
|
||||
return Image.open(path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ensure_gc():
|
||||
# Avoid GPU OOMs
|
||||
|
@ -48,112 +52,112 @@ def ref_path(test_e2e_path: Path) -> Path:
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def cutecat_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "cutecat_init.png").convert("RGB")
|
||||
return _img_open(ref_path / "cutecat_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def kitchen_dog(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "kitchen_dog.png").convert("RGB")
|
||||
return _img_open(ref_path / "kitchen_dog.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")
|
||||
return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def woman_image(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "woman.png").convert("RGB")
|
||||
return _img_open(ref_path / "woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def statue_image(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "statue.png").convert("RGB")
|
||||
return _img_open(ref_path / "statue.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_karras_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_style_aligned(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_style_aligned.png").convert(mode="RGB")
|
||||
return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||
|
@ -161,8 +165,8 @@ def controlnet_data(
|
|||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
||||
cn_name: str = request.param
|
||||
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||
weights_fn = {
|
||||
"depth": "lllyasviel_control_v11f1p_sd15_depth",
|
||||
"canny": "lllyasviel_control_v11p_sd15_canny",
|
||||
|
@ -178,8 +182,8 @@ def controlnet_data(
|
|||
@pytest.fixture(scope="module")
|
||||
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||
cn_name = "canny"
|
||||
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
|
||||
return cn_name, condition_image, expected_image, weights_path
|
||||
|
||||
|
@ -187,8 +191,8 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str,
|
|||
@pytest.fixture(scope="module")
|
||||
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||
cn_name = "depth"
|
||||
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
|
||||
return cn_name, condition_image, expected_image, weights_path
|
||||
|
||||
|
@ -257,12 +261,12 @@ def controllora_sdxl_config(
|
|||
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
|
||||
name: str = request.param[0]
|
||||
configs: dict[str, ControlLoraConfig] = request.param[1]
|
||||
expected_image = Image.open(ref_path / name).convert("RGB")
|
||||
expected_image = _img_open(ref_path / name).convert("RGB")
|
||||
|
||||
loaded_configs = {
|
||||
config_name: ControlLoraResolvedConfig(
|
||||
scale=config.scale,
|
||||
condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
|
||||
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
|
||||
weights_path=test_weights_path / "control-loras" / config.weights_path,
|
||||
)
|
||||
for config_name, config in configs.items()
|
||||
|
@ -274,8 +278,8 @@ def controllora_sdxl_config(
|
|||
@pytest.fixture(scope="module")
|
||||
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||
name = "depth"
|
||||
condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
|
||||
expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
|
||||
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
|
||||
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
|
||||
return name, condition_image, expected_image, weights_path
|
||||
|
||||
|
@ -283,8 +287,8 @@ def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str
|
|||
@pytest.fixture(scope="module")
|
||||
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||
name = "canny"
|
||||
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
||||
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
||||
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
||||
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
|
||||
|
||||
if not weights_path.is_file():
|
||||
|
@ -296,7 +300,7 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
||||
|
||||
if not weights_path.is_file():
|
||||
|
@ -309,7 +313,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||
expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
||||
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
|
||||
|
||||
if not weights_path.is_file():
|
||||
|
@ -341,57 +345,57 @@ def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Ten
|
|||
|
||||
@pytest.fixture
|
||||
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
|
||||
return _img_open(ref_path / "inpainting-scene.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "inpainting-mask.png").convert("RGB")
|
||||
return _img_open(ref_path / "inpainting-mask.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "inpainting-target.png").convert("RGB")
|
||||
return _img_open(ref_path / "inpainting-target.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_refonly(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_refonly.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_refonly.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def condition_image_refonly(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
||||
return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_multi_diffusion(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
||||
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_restart(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB")
|
||||
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_freeu(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
|
||||
return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
|
||||
return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -401,10 +405,10 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
|
|||
image_prompt = assets / "dragon_quest_slime.jpg"
|
||||
condition_image = assets / "dropy_canny.png"
|
||||
return (
|
||||
Image.open(fp=dropy).convert(mode="RGB"),
|
||||
Image.open(fp=image_prompt).convert(mode="RGB"),
|
||||
Image.open(fp=condition_image).convert(mode="RGB"),
|
||||
Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
|
||||
_img_open(dropy).convert(mode="RGB"),
|
||||
_img_open(image_prompt).convert(mode="RGB"),
|
||||
_img_open(condition_image).convert(mode="RGB"),
|
||||
_img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
|
||||
)
|
||||
|
||||
|
||||
|
@ -2314,7 +2318,7 @@ def test_style_aligned(
|
|||
# tile all images horizontally
|
||||
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
|
||||
for i in range(len(predicted_images)):
|
||||
merged_image.paste(predicted_images[i], (i * 1024, 0))
|
||||
merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore
|
||||
|
||||
# compare against reference image
|
||||
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
|
||||
|
|
|
@ -13,6 +13,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
|
|||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
||||
def _img_open(path: Path) -> Image.Image:
|
||||
return Image.open(path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ensure_gc():
|
||||
# Avoid GPU OOMs
|
||||
|
@ -110,27 +114,27 @@ def sdxl(
|
|||
|
||||
@pytest.fixture
|
||||
def image_prompt_german_castle(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "german-castle.jpg").convert("RGB")
|
||||
return _img_open(ref_path / "german-castle.jpg").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
|
|
@ -15,6 +15,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
|
|||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
||||
def _img_open(path: Path) -> Image.Image:
|
||||
return Image.open(path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ensure_gc():
|
||||
# Avoid GPU OOMs
|
||||
|
@ -74,17 +78,17 @@ def ref_path(test_e2e_path: Path) -> Path:
|
|||
|
||||
@pytest.fixture
|
||||
def expected_lcm_base(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_lcm_base.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_lcm_base.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
|
|
@ -14,6 +14,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
|
|||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
||||
def _img_open(path: Path) -> Image.Image:
|
||||
return Image.open(path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ensure_gc():
|
||||
# Avoid GPU OOMs
|
||||
|
@ -82,17 +86,17 @@ def ref_path(test_e2e_path: Path) -> Path:
|
|||
|
||||
@pytest.fixture
|
||||
def expected_lightning_base_4step(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_lightning_base_1step(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
|
||||
return _img_open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
|
|
@ -10,6 +10,10 @@ from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings
|
|||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
||||
def _img_open(path: Path) -> Image.Image:
|
||||
return Image.open(path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def diffusion_ref_path(test_e2e_path: Path) -> Path:
|
||||
return test_e2e_path / "test_diffusion_ref"
|
||||
|
@ -17,12 +21,12 @@ def diffusion_ref_path(test_e2e_path: Path) -> Path:
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
|
||||
return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
|
||||
return _img_open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
|
||||
return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
|
||||
return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
|
@ -70,10 +70,10 @@ def test_image_to_tensor() -> None:
|
|||
|
||||
|
||||
def test_tensor_to_image() -> None:
|
||||
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
|
||||
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
|
||||
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
|
||||
assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB"
|
||||
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" # type: ignore
|
||||
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" # type: ignore
|
||||
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" # type: ignore
|
||||
assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB" # type: ignore
|
||||
|
||||
|
||||
def test_summarize_tensor() -> None:
|
||||
|
|
|
@ -33,7 +33,7 @@ def sample_image(ref_path: Path) -> Image.Image:
|
|||
if not test_image.is_file():
|
||||
warn(f"could not reference image at {test_image}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
img = Image.open(test_image)
|
||||
img = Image.open(test_image) # type: ignore
|
||||
assert img.size == (512, 512)
|
||||
return img
|
||||
|
||||
|
@ -43,7 +43,7 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image:
|
|||
encoded = encoder.image_to_latents(sample_image)
|
||||
decoded = encoder.latents_to_image(encoded)
|
||||
|
||||
assert decoded.mode == "RGB"
|
||||
assert decoded.mode == "RGB" # type: ignore
|
||||
|
||||
# Ensure no saturation. The green channel (band = 1) must not max out.
|
||||
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
|
||||
|
|
|
@ -33,7 +33,7 @@ def one_prompt() -> SAMPrompt:
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def tennis(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "tennis.png").convert("RGB")
|
||||
return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -148,7 +148,7 @@ def test_early_vit_embedding(
|
|||
) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024)))
|
||||
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore
|
||||
|
||||
_ = sam_h.image_encoder(image_tensor.to(sam_h.device))
|
||||
early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]
|
||||
|
@ -284,7 +284,7 @@ def test_predictor_equal(
|
|||
# and to use image_embedding as input
|
||||
|
||||
size = (1024, 1024)
|
||||
resized_tennis = tennis.resize(size)
|
||||
resized_tennis = tennis.resize(size) # type: ignore
|
||||
|
||||
# Reference
|
||||
reference_sam_h_predictor.set_image(np.array(resized_tennis))
|
||||
|
|
|
@ -91,7 +91,7 @@ def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> Segme
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def truck(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "truck.jpg").convert("RGB")
|
||||
return Image.open(ref_path / "truck.jpg").convert("RGB") # type: ignore
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -140,7 +140,8 @@ def test_multimask_output_error() -> None:
|
|||
|
||||
@no_grad()
|
||||
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
||||
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
|
||||
resized = truck.resize(size=(1024, 1024)) # type: ignore
|
||||
image_tensor = image_to_tensor(image=resized, device=facebook_sam_h.device)
|
||||
y_1 = facebook_sam_h.image_encoder(image_tensor)
|
||||
y_2 = sam_h.image_encoder(image_tensor)
|
||||
|
||||
|
@ -451,7 +452,7 @@ def test_predictor_resized_single_output(
|
|||
|
||||
predictor = facebook_sam_h_predictor
|
||||
size = (1024, 1024)
|
||||
resized_truck = truck.resize(size)
|
||||
resized_truck = truck.resize(size) # type: ignore
|
||||
predictor.set_image(np.array(resized_truck))
|
||||
|
||||
_, _, facebook_low_res_masks = predictor.predict( # type: ignore
|
||||
|
|
Loading…
Reference in a new issue