mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
update pyright and fix Pillow 10.3 typing issues
This commit is contained in:
parent
328fcb8ed1
commit
fd5a15c7e0
|
@ -73,7 +73,7 @@ build-backend = "hatchling.build"
|
||||||
[tool.rye]
|
[tool.rye]
|
||||||
managed = true
|
managed = true
|
||||||
dev-dependencies = [
|
dev-dependencies = [
|
||||||
"pyright==1.1.349",
|
"pyright==1.1.356",
|
||||||
"docformatter>=1.7.5",
|
"docformatter>=1.7.5",
|
||||||
"pytest>=8.0.0",
|
"pytest>=8.0.0",
|
||||||
"coverage>=7.4.1",
|
"coverage>=7.4.1",
|
||||||
|
|
|
@ -23,9 +23,9 @@ attrs==23.2.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
babel==2.14.0
|
babel==2.14.0
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
bitsandbytes==0.42.0
|
bitsandbytes==0.43.0
|
||||||
# via refiners
|
# via refiners
|
||||||
black==24.2.0
|
black==24.3.0
|
||||||
# via refiners
|
# via refiners
|
||||||
certifi==2024.2.2
|
certifi==2024.2.2
|
||||||
# via requests
|
# via requests
|
||||||
|
@ -42,14 +42,14 @@ colorama==0.4.6
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
datasets==2.18.0
|
datasets==2.18.0
|
||||||
# via refiners
|
# via refiners
|
||||||
diffusers==0.26.3
|
diffusers==0.27.2
|
||||||
# via refiners
|
# via refiners
|
||||||
dill==0.3.8
|
dill==0.3.8
|
||||||
# via datasets
|
# via datasets
|
||||||
# via multiprocess
|
# via multiprocess
|
||||||
docker-pycreds==0.4.0
|
docker-pycreds==0.4.0
|
||||||
# via wandb
|
# via wandb
|
||||||
filelock==3.13.1
|
filelock==3.13.3
|
||||||
# via datasets
|
# via datasets
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
|
@ -67,23 +67,24 @@ ghp-import==2.1.0
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
gitdb==4.0.11
|
gitdb==4.0.11
|
||||||
# via gitpython
|
# via gitpython
|
||||||
gitpython==3.1.42
|
gitpython==3.1.43
|
||||||
# via wandb
|
# via wandb
|
||||||
griffe==0.41.3
|
griffe==0.42.1
|
||||||
# via mkdocstrings-python
|
# via mkdocstrings-python
|
||||||
huggingface-hub==0.21.3
|
huggingface-hub==0.22.2
|
||||||
# via datasets
|
# via datasets
|
||||||
# via diffusers
|
# via diffusers
|
||||||
|
# via timm
|
||||||
# via tokenizers
|
# via tokenizers
|
||||||
# via transformers
|
# via transformers
|
||||||
idna==3.6
|
idna==3.6
|
||||||
# via requests
|
# via requests
|
||||||
# via yarl
|
# via yarl
|
||||||
importlib-metadata==7.0.1
|
importlib-metadata==7.1.0
|
||||||
# via diffusers
|
# via diffusers
|
||||||
invisible-watermark==0.2.0
|
invisible-watermark==0.2.0
|
||||||
# via refiners
|
# via refiners
|
||||||
jaxtyping==0.2.25
|
jaxtyping==0.2.28
|
||||||
# via refiners
|
# via refiners
|
||||||
jinja2==3.1.3
|
jinja2==3.1.3
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
|
@ -92,7 +93,7 @@ jinja2==3.1.3
|
||||||
# via torch
|
# via torch
|
||||||
loguru==0.7.2
|
loguru==0.7.2
|
||||||
# via refiners
|
# via refiners
|
||||||
markdown==3.5.2
|
markdown==3.6
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
# via mkdocs-autorefs
|
# via mkdocs-autorefs
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
|
@ -114,7 +115,7 @@ mkdocs-autorefs==1.0.1
|
||||||
# via mkdocstrings
|
# via mkdocstrings
|
||||||
mkdocs-literate-nav==0.6.1
|
mkdocs-literate-nav==0.6.1
|
||||||
# via refiners
|
# via refiners
|
||||||
mkdocs-material==9.5.12
|
mkdocs-material==9.5.17
|
||||||
# via refiners
|
# via refiners
|
||||||
mkdocs-material-extensions==1.3.1
|
mkdocs-material-extensions==1.3.1
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
|
@ -135,6 +136,7 @@ mypy-extensions==1.0.0
|
||||||
networkx==3.2.1
|
networkx==3.2.1
|
||||||
# via torch
|
# via torch
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
|
# via bitsandbytes
|
||||||
# via datasets
|
# via datasets
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via invisible-watermark
|
# via invisible-watermark
|
||||||
|
@ -143,7 +145,6 @@ numpy==1.26.4
|
||||||
# via pandas
|
# via pandas
|
||||||
# via pyarrow
|
# via pyarrow
|
||||||
# via pywavelets
|
# via pywavelets
|
||||||
# via scipy
|
|
||||||
# via torchvision
|
# via torchvision
|
||||||
# via transformers
|
# via transformers
|
||||||
nvidia-cublas-cu12==12.1.3.1
|
nvidia-cublas-cu12==12.1.3.1
|
||||||
|
@ -169,14 +170,14 @@ nvidia-cusparse-cu12==12.1.0.106
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-nccl-cu12==2.19.3
|
nvidia-nccl-cu12==2.19.3
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-nvjitlink-cu12==12.3.101
|
nvidia-nvjitlink-cu12==12.4.99
|
||||||
# via nvidia-cusolver-cu12
|
# via nvidia-cusolver-cu12
|
||||||
# via nvidia-cusparse-cu12
|
# via nvidia-cusparse-cu12
|
||||||
nvidia-nvtx-cu12==12.1.105
|
nvidia-nvtx-cu12==12.1.105
|
||||||
# via torch
|
# via torch
|
||||||
opencv-python==4.9.0.80
|
opencv-python==4.9.0.80
|
||||||
# via invisible-watermark
|
# via invisible-watermark
|
||||||
packaging==23.2
|
packaging==24.0
|
||||||
# via black
|
# via black
|
||||||
# via datasets
|
# via datasets
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
|
@ -190,7 +191,7 @@ pandas==2.2.1
|
||||||
pathspec==0.12.1
|
pathspec==0.12.1
|
||||||
# via black
|
# via black
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
pillow==10.2.0
|
pillow==10.3.0
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via invisible-watermark
|
# via invisible-watermark
|
||||||
# via refiners
|
# via refiners
|
||||||
|
@ -207,11 +208,11 @@ protobuf==4.25.3
|
||||||
# via wandb
|
# via wandb
|
||||||
psutil==5.9.8
|
psutil==5.9.8
|
||||||
# via wandb
|
# via wandb
|
||||||
pyarrow==15.0.0
|
pyarrow==15.0.2
|
||||||
# via datasets
|
# via datasets
|
||||||
pyarrow-hotfix==0.6
|
pyarrow-hotfix==0.6
|
||||||
# via datasets
|
# via datasets
|
||||||
pydantic==2.6.3
|
pydantic==2.6.4
|
||||||
# via refiners
|
# via refiners
|
||||||
pydantic-core==2.16.3
|
pydantic-core==2.16.3
|
||||||
# via pydantic
|
# via pydantic
|
||||||
|
@ -225,7 +226,7 @@ python-dateutil==2.9.0.post0
|
||||||
# via pandas
|
# via pandas
|
||||||
pytz==2024.1
|
pytz==2024.1
|
||||||
# via pandas
|
# via pandas
|
||||||
pywavelets==1.5.0
|
pywavelets==1.6.0
|
||||||
# via invisible-watermark
|
# via invisible-watermark
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
# via datasets
|
# via datasets
|
||||||
|
@ -233,6 +234,7 @@ pyyaml==6.0.1
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
# via pymdown-extensions
|
# via pymdown-extensions
|
||||||
# via pyyaml-env-tag
|
# via pyyaml-env-tag
|
||||||
|
# via timm
|
||||||
# via transformers
|
# via transformers
|
||||||
# via wandb
|
# via wandb
|
||||||
pyyaml-env-tag==0.1
|
pyyaml-env-tag==0.1
|
||||||
|
@ -252,17 +254,17 @@ requests==2.31.0
|
||||||
safetensors==0.4.2
|
safetensors==0.4.2
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via refiners
|
# via refiners
|
||||||
|
# via timm
|
||||||
# via transformers
|
# via transformers
|
||||||
scipy==1.12.0
|
|
||||||
# via bitsandbytes
|
|
||||||
segment-anything-hq==0.3
|
segment-anything-hq==0.3
|
||||||
|
# via refiners
|
||||||
segment-anything-py==1.0
|
segment-anything-py==1.0
|
||||||
# via refiners
|
# via refiners
|
||||||
sentry-sdk==1.40.6
|
sentry-sdk==1.44.0
|
||||||
# via wandb
|
# via wandb
|
||||||
setproctitle==1.3.3
|
setproctitle==1.3.3
|
||||||
# via wandb
|
# via wandb
|
||||||
setuptools==69.1.1
|
setuptools==69.2.0
|
||||||
# via wandb
|
# via wandb
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
# via docker-pycreds
|
# via docker-pycreds
|
||||||
|
@ -272,26 +274,32 @@ smmap==5.0.1
|
||||||
sympy==1.12
|
sympy==1.12
|
||||||
# via torch
|
# via torch
|
||||||
timm==0.9.16
|
timm==0.9.16
|
||||||
|
# via refiners
|
||||||
tokenizers==0.15.2
|
tokenizers==0.15.2
|
||||||
# via transformers
|
# via transformers
|
||||||
tomli==2.0.1
|
tomli==2.0.1
|
||||||
# via black
|
# via black
|
||||||
# via refiners
|
# via refiners
|
||||||
torch==2.2.1
|
torch==2.2.2
|
||||||
|
# via bitsandbytes
|
||||||
# via invisible-watermark
|
# via invisible-watermark
|
||||||
# via refiners
|
# via refiners
|
||||||
|
# via segment-anything-hq
|
||||||
# via segment-anything-py
|
# via segment-anything-py
|
||||||
|
# via timm
|
||||||
# via torchvision
|
# via torchvision
|
||||||
torchvision==0.17.1
|
torchvision==0.17.2
|
||||||
# via piq
|
# via piq
|
||||||
# via refiners
|
# via refiners
|
||||||
|
# via segment-anything-hq
|
||||||
# via segment-anything-py
|
# via segment-anything-py
|
||||||
|
# via timm
|
||||||
tqdm==4.66.2
|
tqdm==4.66.2
|
||||||
# via datasets
|
# via datasets
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via refiners
|
# via refiners
|
||||||
# via transformers
|
# via transformers
|
||||||
transformers==4.38.2
|
transformers==4.39.3
|
||||||
# via refiners
|
# via refiners
|
||||||
triton==2.2.0
|
triton==2.2.0
|
||||||
# via torch
|
# via torch
|
||||||
|
@ -300,7 +308,6 @@ typeguard==2.13.3
|
||||||
typing-extensions==4.10.0
|
typing-extensions==4.10.0
|
||||||
# via black
|
# via black
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via jaxtyping
|
|
||||||
# via pydantic
|
# via pydantic
|
||||||
# via pydantic-core
|
# via pydantic-core
|
||||||
# via torch
|
# via torch
|
||||||
|
@ -309,7 +316,7 @@ tzdata==2024.1
|
||||||
urllib3==2.2.1
|
urllib3==2.2.1
|
||||||
# via requests
|
# via requests
|
||||||
# via sentry-sdk
|
# via sentry-sdk
|
||||||
wandb==0.16.3
|
wandb==0.16.5
|
||||||
# via refiners
|
# via refiners
|
||||||
watchdog==4.0.0
|
watchdog==4.0.0
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
|
@ -317,5 +324,5 @@ xxhash==3.4.1
|
||||||
# via datasets
|
# via datasets
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
zipp==3.17.0
|
zipp==3.18.1
|
||||||
# via importlib-metadata
|
# via importlib-metadata
|
||||||
|
|
|
@ -145,6 +145,7 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
|
||||||
"""
|
"""
|
||||||
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
|
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
assert isinstance(image.mode, str) # type: ignore
|
||||||
match image.mode:
|
match image.mode:
|
||||||
case "L":
|
case "L":
|
||||||
image_tensor = image_tensor.unsqueeze(0)
|
image_tensor = image_tensor.unsqueeze(0)
|
||||||
|
|
|
@ -537,8 +537,9 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
mean: The mean to use for normalization.
|
mean: The mean to use for normalization.
|
||||||
std: The standard deviation to use for normalization.
|
std: The standard deviation to use for normalization.
|
||||||
"""
|
"""
|
||||||
|
resized = image.resize(size) # type: ignore
|
||||||
return normalize(
|
return normalize(
|
||||||
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
|
image_to_tensor(resized, device=self.target.device, dtype=self.target.dtype),
|
||||||
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
||||||
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
|
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
|
||||||
)
|
)
|
||||||
|
|
|
@ -52,7 +52,8 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
if init_image is None:
|
if init_image is None:
|
||||||
x = noise
|
x = noise
|
||||||
else:
|
else:
|
||||||
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
|
resized = init_image.resize(size=(width, height)) # type: ignore
|
||||||
|
encoded_image = self.lda.image_to_latents(resized)
|
||||||
x = self.solver.add_noise(
|
x = self.solver.add_noise(
|
||||||
x=encoded_image,
|
x=encoded_image,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
|
|
|
@ -197,8 +197,9 @@ class SegmentAnything(fl.Chain):
|
||||||
h, w = target_size
|
h, w = target_size
|
||||||
padh = self.image_size - h
|
padh = self.image_size - h
|
||||||
padw = self.image_size - w
|
padw = self.image_size - w
|
||||||
|
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
|
||||||
image_tensor = torch.tensor(
|
image_tensor = torch.tensor(
|
||||||
np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1),
|
np.array(resized).astype(np.float32).transpose(2, 0, 1),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from abc import ABC, abstractmethod, abstractproperty
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property, wraps
|
from functools import cached_property, wraps
|
||||||
from typing import Any, Callable, Generic, Literal, TypeVar, cast
|
from typing import Any, Callable, Generic, Literal, TypeVar, cast
|
||||||
|
@ -300,7 +300,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractproperty
|
@property
|
||||||
|
@abstractmethod
|
||||||
def dataset_length(self) -> int:
|
def dataset_length(self) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the length of the dataset.
|
Returns the length of the dataset.
|
||||||
|
|
|
@ -34,6 +34,10 @@ from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAd
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
def _img_open(path: Path) -> Image.Image:
|
||||||
|
return Image.open(path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def ensure_gc():
|
def ensure_gc():
|
||||||
# Avoid GPU OOMs
|
# Avoid GPU OOMs
|
||||||
|
@ -48,112 +52,112 @@ def ref_path(test_e2e_path: Path) -> Path:
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def cutecat_init(ref_path: Path) -> Image.Image:
|
def cutecat_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "cutecat_init.png").convert("RGB")
|
return _img_open(ref_path / "cutecat_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def kitchen_dog(ref_path: Path) -> Image.Image:
|
def kitchen_dog(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "kitchen_dog.png").convert("RGB")
|
return _img_open(ref_path / "kitchen_dog.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
|
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")
|
return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def woman_image(ref_path: Path) -> Image.Image:
|
def woman_image(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "woman.png").convert("RGB")
|
return _img_open(ref_path / "woman.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def statue_image(ref_path: Path) -> Image.Image:
|
def statue_image(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "statue.png").convert("RGB")
|
return _img_open(ref_path / "statue.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_karras_random_init(ref_path: Path) -> Image.Image:
|
def expected_karras_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
|
return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
|
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
|
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
|
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
|
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
|
return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
|
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
|
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
|
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
|
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
|
return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
|
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
|
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
|
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
|
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_style_aligned(ref_path: Path) -> Image.Image:
|
def expected_style_aligned(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_style_aligned.png").convert(mode="RGB")
|
return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||||
|
@ -161,8 +165,8 @@ def controlnet_data(
|
||||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||||
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
||||||
cn_name: str = request.param
|
cn_name: str = request.param
|
||||||
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
weights_fn = {
|
weights_fn = {
|
||||||
"depth": "lllyasviel_control_v11f1p_sd15_depth",
|
"depth": "lllyasviel_control_v11f1p_sd15_depth",
|
||||||
"canny": "lllyasviel_control_v11p_sd15_canny",
|
"canny": "lllyasviel_control_v11p_sd15_canny",
|
||||||
|
@ -178,8 +182,8 @@ def controlnet_data(
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
cn_name = "canny"
|
cn_name = "canny"
|
||||||
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
|
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
|
||||||
return cn_name, condition_image, expected_image, weights_path
|
return cn_name, condition_image, expected_image, weights_path
|
||||||
|
|
||||||
|
@ -187,8 +191,8 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str,
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
cn_name = "depth"
|
cn_name = "depth"
|
||||||
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
|
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
|
||||||
return cn_name, condition_image, expected_image, weights_path
|
return cn_name, condition_image, expected_image, weights_path
|
||||||
|
|
||||||
|
@ -257,12 +261,12 @@ def controllora_sdxl_config(
|
||||||
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
|
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
|
||||||
name: str = request.param[0]
|
name: str = request.param[0]
|
||||||
configs: dict[str, ControlLoraConfig] = request.param[1]
|
configs: dict[str, ControlLoraConfig] = request.param[1]
|
||||||
expected_image = Image.open(ref_path / name).convert("RGB")
|
expected_image = _img_open(ref_path / name).convert("RGB")
|
||||||
|
|
||||||
loaded_configs = {
|
loaded_configs = {
|
||||||
config_name: ControlLoraResolvedConfig(
|
config_name: ControlLoraResolvedConfig(
|
||||||
scale=config.scale,
|
scale=config.scale,
|
||||||
condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
|
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
|
||||||
weights_path=test_weights_path / "control-loras" / config.weights_path,
|
weights_path=test_weights_path / "control-loras" / config.weights_path,
|
||||||
)
|
)
|
||||||
for config_name, config in configs.items()
|
for config_name, config in configs.items()
|
||||||
|
@ -274,8 +278,8 @@ def controllora_sdxl_config(
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
name = "depth"
|
name = "depth"
|
||||||
condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
|
||||||
expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
|
weights_path = test_weights_path / "T2I-Adapter" / "t2iadapter_depth_sd15v2.safetensors"
|
||||||
return name, condition_image, expected_image, weights_path
|
return name, condition_image, expected_image, weights_path
|
||||||
|
|
||||||
|
@ -283,8 +287,8 @@ def t2i_adapter_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
name = "canny"
|
name = "canny"
|
||||||
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
||||||
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
|
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
|
||||||
|
|
||||||
if not weights_path.is_file():
|
if not weights_path.is_file():
|
||||||
|
@ -296,7 +300,7 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||||
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
||||||
|
|
||||||
if not weights_path.is_file():
|
if not weights_path.is_file():
|
||||||
|
@ -309,7 +313,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||||
expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
|
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
|
||||||
|
|
||||||
if not weights_path.is_file():
|
if not weights_path.is_file():
|
||||||
|
@ -341,57 +345,57 @@ def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Ten
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
|
return _img_open(ref_path / "inpainting-scene.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "inpainting-mask.png").convert("RGB")
|
return _img_open(ref_path / "inpainting-mask.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "inpainting-target.png").convert("RGB")
|
return _img_open(ref_path / "inpainting-target.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
|
return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_refonly(ref_path: Path) -> Image.Image:
|
def expected_image_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_refonly.png").convert("RGB")
|
return _img_open(ref_path / "expected_refonly.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def condition_image_refonly(ref_path: Path) -> Image.Image:
|
def condition_image_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
|
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
|
return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_multi_diffusion(ref_path: Path) -> Image.Image:
|
def expected_multi_diffusion(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_restart(ref_path: Path) -> Image.Image:
|
def expected_restart(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB")
|
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_freeu(ref_path: Path) -> Image.Image:
|
def expected_freeu(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
|
return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
|
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
|
return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -401,10 +405,10 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
|
||||||
image_prompt = assets / "dragon_quest_slime.jpg"
|
image_prompt = assets / "dragon_quest_slime.jpg"
|
||||||
condition_image = assets / "dropy_canny.png"
|
condition_image = assets / "dropy_canny.png"
|
||||||
return (
|
return (
|
||||||
Image.open(fp=dropy).convert(mode="RGB"),
|
_img_open(dropy).convert(mode="RGB"),
|
||||||
Image.open(fp=image_prompt).convert(mode="RGB"),
|
_img_open(image_prompt).convert(mode="RGB"),
|
||||||
Image.open(fp=condition_image).convert(mode="RGB"),
|
_img_open(condition_image).convert(mode="RGB"),
|
||||||
Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
|
_img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2314,7 +2318,7 @@ def test_style_aligned(
|
||||||
# tile all images horizontally
|
# tile all images horizontally
|
||||||
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
|
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
|
||||||
for i in range(len(predicted_images)):
|
for i in range(len(predicted_images)):
|
||||||
merged_image.paste(predicted_images[i], (i * 1024, 0))
|
merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore
|
||||||
|
|
||||||
# compare against reference image
|
# compare against reference image
|
||||||
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
|
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
|
@ -13,6 +13,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
def _img_open(path: Path) -> Image.Image:
|
||||||
|
return Image.open(path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def ensure_gc():
|
def ensure_gc():
|
||||||
# Avoid GPU OOMs
|
# Avoid GPU OOMs
|
||||||
|
@ -110,27 +114,27 @@ def sdxl(
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def image_prompt_german_castle(ref_path: Path) -> Image.Image:
|
def image_prompt_german_castle(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "german-castle.jpg").convert("RGB")
|
return _img_open(ref_path / "german-castle.jpg").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
|
def expected_image_guide_adapting_sdxl_vanilla(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_vanilla.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
|
def expected_image_guide_adapting_sdxl_single_lora(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_single_lora.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
|
def expected_image_guide_adapting_sdxl_multiple_loras(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_multiple_loras.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
|
def expected_image_guide_adapting_sdxl_loras_ip_adapter(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
|
return _img_open(ref_path / "expected_image_guide_adapting_sdxl_loras_ip_adapter.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
|
|
|
@ -15,6 +15,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
def _img_open(path: Path) -> Image.Image:
|
||||||
|
return Image.open(path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def ensure_gc():
|
def ensure_gc():
|
||||||
# Avoid GPU OOMs
|
# Avoid GPU OOMs
|
||||||
|
@ -74,17 +78,17 @@ def ref_path(test_e2e_path: Path) -> Path:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_lcm_base(ref_path: Path) -> Image.Image:
|
def expected_lcm_base(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_lcm_base.png").convert("RGB")
|
return _img_open(ref_path / "expected_lcm_base.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image:
|
def expected_lcm_lora_1_0(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
|
return _img_open(ref_path / "expected_lcm_lora_1_0.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
|
def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
|
return _img_open(ref_path / "expected_lcm_lora_1_2.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
|
|
|
@ -14,6 +14,10 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
def _img_open(path: Path) -> Image.Image:
|
||||||
|
return Image.open(path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def ensure_gc():
|
def ensure_gc():
|
||||||
# Avoid GPU OOMs
|
# Avoid GPU OOMs
|
||||||
|
@ -82,17 +86,17 @@ def ref_path(test_e2e_path: Path) -> Path:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_lightning_base_4step(ref_path: Path) -> Image.Image:
|
def expected_lightning_base_4step(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
|
return _img_open(ref_path / "expected_lightning_base_4step.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_lightning_base_1step(ref_path: Path) -> Image.Image:
|
def expected_lightning_base_1step(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
|
return _img_open(ref_path / "expected_lightning_base_1step.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
|
def expected_lightning_lora_4step(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
|
return _img_open(ref_path / "expected_lightning_lora_4step.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
|
|
|
@ -10,6 +10,10 @@ from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
def _img_open(path: Path) -> Image.Image:
|
||||||
|
return Image.open(path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def diffusion_ref_path(test_e2e_path: Path) -> Path:
|
def diffusion_ref_path(test_e2e_path: Path) -> Path:
|
||||||
return test_e2e_path / "test_diffusion_ref"
|
return test_e2e_path / "test_diffusion_ref"
|
||||||
|
@ -17,12 +21,12 @@ def diffusion_ref_path(test_e2e_path: Path) -> Path:
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
|
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
|
||||||
return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
|
return _img_open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
|
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
|
||||||
return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
|
return _img_open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
|
|
@ -70,10 +70,10 @@ def test_image_to_tensor() -> None:
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_to_image() -> None:
|
def test_tensor_to_image() -> None:
|
||||||
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
|
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" # type: ignore
|
||||||
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
|
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" # type: ignore
|
||||||
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
|
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" # type: ignore
|
||||||
assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB"
|
assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB" # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_summarize_tensor() -> None:
|
def test_summarize_tensor() -> None:
|
||||||
|
|
|
@ -33,7 +33,7 @@ def sample_image(ref_path: Path) -> Image.Image:
|
||||||
if not test_image.is_file():
|
if not test_image.is_file():
|
||||||
warn(f"could not reference image at {test_image}, skipping")
|
warn(f"could not reference image at {test_image}, skipping")
|
||||||
pytest.skip(allow_module_level=True)
|
pytest.skip(allow_module_level=True)
|
||||||
img = Image.open(test_image)
|
img = Image.open(test_image) # type: ignore
|
||||||
assert img.size == (512, 512)
|
assert img.size == (512, 512)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image:
|
||||||
encoded = encoder.image_to_latents(sample_image)
|
encoded = encoder.image_to_latents(sample_image)
|
||||||
decoded = encoder.latents_to_image(encoded)
|
decoded = encoder.latents_to_image(encoded)
|
||||||
|
|
||||||
assert decoded.mode == "RGB"
|
assert decoded.mode == "RGB" # type: ignore
|
||||||
|
|
||||||
# Ensure no saturation. The green channel (band = 1) must not max out.
|
# Ensure no saturation. The green channel (band = 1) must not max out.
|
||||||
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
|
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
|
||||||
|
|
|
@ -33,7 +33,7 @@ def one_prompt() -> SAMPrompt:
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def tennis(ref_path: Path) -> Image.Image:
|
def tennis(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "tennis.png").convert("RGB")
|
return Image.open(ref_path / "tennis.png").convert("RGB") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -148,7 +148,7 @@ def test_early_vit_embedding(
|
||||||
) -> None:
|
) -> None:
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||||
|
|
||||||
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024)))
|
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) # type: ignore
|
||||||
|
|
||||||
_ = sam_h.image_encoder(image_tensor.to(sam_h.device))
|
_ = sam_h.image_encoder(image_tensor.to(sam_h.device))
|
||||||
early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]
|
early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]
|
||||||
|
@ -284,7 +284,7 @@ def test_predictor_equal(
|
||||||
# and to use image_embedding as input
|
# and to use image_embedding as input
|
||||||
|
|
||||||
size = (1024, 1024)
|
size = (1024, 1024)
|
||||||
resized_tennis = tennis.resize(size)
|
resized_tennis = tennis.resize(size) # type: ignore
|
||||||
|
|
||||||
# Reference
|
# Reference
|
||||||
reference_sam_h_predictor.set_image(np.array(resized_tennis))
|
reference_sam_h_predictor.set_image(np.array(resized_tennis))
|
||||||
|
|
|
@ -91,7 +91,7 @@ def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> Segme
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def truck(ref_path: Path) -> Image.Image:
|
def truck(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "truck.jpg").convert("RGB")
|
return Image.open(ref_path / "truck.jpg").convert("RGB") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
|
@ -140,7 +140,8 @@ def test_multimask_output_error() -> None:
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
||||||
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
|
resized = truck.resize(size=(1024, 1024)) # type: ignore
|
||||||
|
image_tensor = image_to_tensor(image=resized, device=facebook_sam_h.device)
|
||||||
y_1 = facebook_sam_h.image_encoder(image_tensor)
|
y_1 = facebook_sam_h.image_encoder(image_tensor)
|
||||||
y_2 = sam_h.image_encoder(image_tensor)
|
y_2 = sam_h.image_encoder(image_tensor)
|
||||||
|
|
||||||
|
@ -451,7 +452,7 @@ def test_predictor_resized_single_output(
|
||||||
|
|
||||||
predictor = facebook_sam_h_predictor
|
predictor = facebook_sam_h_predictor
|
||||||
size = (1024, 1024)
|
size = (1024, 1024)
|
||||||
resized_truck = truck.resize(size)
|
resized_truck = truck.resize(size) # type: ignore
|
||||||
predictor.set_image(np.array(resized_truck))
|
predictor.set_image(np.array(resized_truck))
|
||||||
|
|
||||||
_, _, facebook_low_res_masks = predictor.predict( # type: ignore
|
_, _, facebook_low_res_masks = predictor.predict( # type: ignore
|
||||||
|
|
Loading…
Reference in a new issue