From 7dc2e93cfffe94652467d821782c6df310eaaf57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Deltheil?=
Date: Wed, 30 Aug 2023 10:20:55 +0200
Subject: [PATCH] tests: add test for clip image encoder

This covers a CLIPImageEncoderH model (Stable Diffusion v2-1-unclip) specifically
---
 .../foundationals/clip/test_image_encoder.py | 53 +++++++++++++++++++
 1 file changed, 53 insertions(+)
 create mode 100644 tests/foundationals/clip/test_image_encoder.py

diff --git a/tests/foundationals/clip/test_image_encoder.py b/tests/foundationals/clip/test_image_encoder.py
new file mode 100644
index 0000000..e3027c9
--- /dev/null
+++ b/tests/foundationals/clip/test_image_encoder.py
@@ -0,0 +1,53 @@
+import torch
+import pytest
+
+from warnings import warn
+from pathlib import Path
+
+from transformers import CLIPVisionModelWithProjection  # type: ignore
+
+from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
+from refiners.fluxion.utils import load_from_safetensors
+
+
+@pytest.fixture(scope="module")
+def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH:
+    weights = test_weights_path / "CLIPImageEncoderH.safetensors"
+    if not weights.is_file():
+        warn(f"could not find weights at {weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    encoder = CLIPImageEncoderH(device=test_device)
+    tensors = load_from_safetensors(weights)
+    encoder.load_state_dict(tensors)
+    return encoder
+
+
+@pytest.fixture(scope="module")
+def stabilityai_unclip_weights_path(test_weights_path: Path):
+    r = test_weights_path / "stabilityai" / "stable-diffusion-2-1-unclip"
+    if not r.is_dir():
+        warn(f"could not find Stability AI weights at {r}, skipping")
+        pytest.skip(allow_module_level=True)
+    return r
+
+
+@pytest.fixture(scope="module")
+def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
+    return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to(test_device)  # type: ignore
+
+
+def test_encoder(
+    ref_encoder: CLIPVisionModelWithProjection,
+    our_encoder: CLIPImageEncoderH,
+    test_device: torch.device,
+):
+    x = torch.randn(1, 3, 224, 224).to(test_device)
+
+    with torch.no_grad():
+        ref_embeddings = ref_encoder(x).image_embeds
+        our_embeddings = our_encoder(x)
+
+    assert ref_embeddings.shape == (1, 1024)
+    assert our_embeddings.shape == (1, 1024)
+
+    assert (our_embeddings - ref_embeddings).abs().max() < 0.01
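
For a quick manual sanity check outside the test suite, the converted encoder can also be exercised on its own. The snippet below is a minimal sketch reusing only the APIs shown in the diff above; the weights path is hypothetical and stands in for the test_weights_path fixture.

import torch

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH

# Hypothetical local path to the converted weights used by the our_encoder fixture.
weights = "tests/weights/CLIPImageEncoderH.safetensors"

encoder = CLIPImageEncoderH(device=torch.device("cpu"))
encoder.load_state_dict(load_from_safetensors(weights))

# 224x224 RGB input in, (1, 1024) image embedding out, as asserted in the test.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    embedding = encoder(x)
assert embedding.shape == (1, 1024)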