from pathlib import Path

import pytest
import torch
from transformers import CLIPVisionModelWithProjection  # type: ignore

from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH


@pytest.fixture(scope="module")
def our_encoder(
    clip_image_encoder_huge_weights_path: Path,
    test_device: torch.device,
    test_dtype_fp32_bf16_fp16: torch.dtype,
) -> CLIPImageEncoderH:
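    # Build the refiners encoder, then load its weights from the safetensors checkpoint.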
    encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
    tensors = load_from_safetensors(clip_image_encoder_huge_weights_path)
    encoder.load_state_dict(tensors)
    return encoder


@pytest.fixture(scope="module")
def ref_encoder(
    unclip21_transformers_stabilityai_path: str,
    test_device: torch.device,
    test_dtype_fp32_bf16_fp16: torch.dtype,
    use_local_weights: bool,
) -> CLIPVisionModelWithProjection:
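    # Load the reference transformers implementation from the image_encoder subfolder of the unCLIP 2.1 checkpoint.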
    return CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
        unclip21_transformers_stabilityai_path,
        local_files_only=use_local_weights,
        subfolder="image_encoder",
    ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)  # type: ignore


@no_grad()
@pytest.mark.flaky(reruns=3)
def test_encoder(
    ref_encoder: CLIPVisionModelWithProjection,
    our_encoder: CLIPImageEncoderH,
):
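    # Both fixtures are built with the same device and dtype; verify that before comparing outputs.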
    assert ref_encoder.dtype == our_encoder.dtype
    assert ref_encoder.device == our_encoder.device
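    # A single random 224x224 RGB input, matching the encoder's expected resolution.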
    x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)

    ref_embeddings = ref_encoder(x).image_embeds
    our_embeddings = our_encoder(x)

    assert ref_embeddings.shape == (1, 1024)
    assert our_embeddings.shape == (1, 1024)
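
    # The refiners implementation should match the reference within a loose absolute tolerance
    # (the fixtures may run in bf16/fp16).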
    assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)