refiners/tests/foundationals/latent_diffusion/test_auto_encoder.py
Colle 4a6146bb6c clip text, lda encode batch inputs
* text_encoder([str1, str2])
* lda decode_latents/encode_image image_to_latent/latent_to_image
* images_to_tensor, tensor_to_images
---------
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
2024-02-01 17:05:28 +01:00

61 lines
2.1 KiB
Python

from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@pytest.fixture(scope="module")
def ref_path() -> Path:
    """Return the `test_auto_encoder_ref` directory that sits next to this test module."""
    module_dir = Path(__file__).parent
    return module_dir / "test_auto_encoder_ref"
@pytest.fixture(scope="module")
def encoder(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
    """Provide a `LatentDiffusionAutoencoder` on `test_device` loaded from `lda.safetensors`.

    The weights are looked up under `test_weights_path`. When the file is absent,
    a warning is emitted and the dependent tests are skipped.
    """
    weights_file = test_weights_path / "lda.safetensors"

    if weights_file.is_file():
        autoencoder = LatentDiffusionAutoencoder(device=test_device)
        autoencoder.load_state_dict(load_from_safetensors(weights_file))
        return autoencoder

    # No weights available: report it, then skip (pytest.skip raises, so nothing follows).
    warn(f"could not find weights at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
@pytest.fixture(scope="module")
def sample_image(ref_path: Path) -> Image.Image:
    """Provide the `macaw.png` reference image stored under `ref_path`.

    When the file is absent, a warning is emitted and the dependent tests are
    skipped. The image is expected to be exactly 512x512 pixels.
    """
    test_image = ref_path / "macaw.png"
    if not test_image.is_file():
        # The message now says what actually failed: the reference image was not found.
        warn(f"could not find reference image at {test_image}, skipping")
        pytest.skip(allow_module_level=True)
    img = Image.open(test_image)
    # Guard against a swapped or resized reference asset.
    assert img.size == (512, 512)
    return img
@no_grad()
def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip one image through the autoencoder and check the reconstruction."""
    latents = encoder.image_to_latents(sample_image)
    reconstruction = encoder.latents_to_image(latents)

    assert reconstruction.mode == "RGB"

    # Saturation guard: the green channel (band index 1) must never reach 255.
    green_peak = max(iter(reconstruction.getdata(band=1)))  # type: ignore
    assert green_peak < 255

    ensure_similar_images(sample_image, reconstruction, min_psnr=20, min_ssim=0.9)
@no_grad()
def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a batch of two images and check every reconstruction.

    The batch holds the same sample image twice; the decoded result must be a
    list with one image per input, each similar to its source.
    """
    batch = [sample_image, sample_image]
    encoded = encoder.images_to_latents(batch)
    images = encoder.latents_to_images(encoded)

    assert isinstance(images, list)
    assert len(images) == len(batch)

    # Compare every element, not just one: a defect affecting only part of the
    # batch would otherwise go unnoticed.
    for source, decoded in zip(batch, images):
        ensure_similar_images(source, decoded, min_psnr=20, min_ssim=0.9)