diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py
index 5020cc8..9b5f40e 100644
--- a/tests/foundationals/dinov2/test_dinov2.py
+++ b/tests/foundationals/dinov2/test_dinov2.py
@@ -1,3 +1,4 @@
+from math import isclose
 from pathlib import Path
 from warnings import warn
 
@@ -7,16 +8,23 @@ from transformers import AutoModel  # type: ignore
 from transformers.models.dinov2.modeling_dinov2 import Dinov2Model  # type: ignore
 
 from refiners.fluxion.utils import load_from_safetensors, manual_seed
-from refiners.foundationals.dinov2 import DINOv2_base, DINOv2_large, DINOv2_small
+from refiners.foundationals.dinov2 import (
+    DINOv2_base,
+    DINOv2_base_reg,
+    DINOv2_large,
+    DINOv2_large_reg,
+    DINOv2_small,
+    DINOv2_small_reg,
+)
 from refiners.foundationals.dinov2.vit import ViT
 
-# TODO: add DINOv2 with registers ("dinov2_vits14_reg", etc). At the time of writing, those are not yet supported in
-# transformers (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
-# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
 FLAVORS = [
     "dinov2_vits14",
     "dinov2_vitb14",
     "dinov2_vitl14",
+    "dinov2_vits14_reg4",
+    "dinov2_vitb14_reg4",
+    "dinov2_vitl14_reg4",
 ]
 
 
@@ -25,6 +33,26 @@ def flavor(request: pytest.FixtureRequest) -> str:
     return request.param
 
 
+# Temporary: see comments in `test_encoder_only`
+@pytest.fixture(scope="module")
+def seed_expected_norm(flavor: str) -> tuple[int, float]:
+    match flavor:
+        case "dinov2_vits14":
+            return (42, 1977.9213867)
+        case "dinov2_vitb14":
+            return (42, 1902.6384277)
+        case "dinov2_vitl14":
+            return (42, 1763.9187011)
+        case "dinov2_vits14_reg4":
+            return (42, 989.2380981)
+        case "dinov2_vitb14_reg4":
+            return (42, 974.4362182)
+        case "dinov2_vitl14_reg4":
+            return (42, 924.8797607)
+        case _:
+            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
+
+
 @pytest.fixture(scope="module")
 def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
     weights = test_weights_path / f"{flavor}_pretrain.safetensors"
@@ -38,6 +66,12 @@ def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device
             backbone = DINOv2_base(device=test_device)
         case "dinov2_vitl14":
             backbone = DINOv2_large(device=test_device)
+        case "dinov2_vits14_reg4":
+            backbone = DINOv2_small_reg(device=test_device)
+        case "dinov2_vitb14_reg4":
+            backbone = DINOv2_base_reg(device=test_device)
+        case "dinov2_vitl14_reg4":
+            backbone = DINOv2_large_reg(device=test_device)
         case _:
             raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
     tensors = load_from_safetensors(weights)
@@ -47,6 +81,12 @@ def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device
 
 @pytest.fixture(scope="module")
 def dinov2_weights_path(test_weights_path: Path, flavor: str):
+    # TODO: DINOv2 with registers is not yet supported in transformers at the time of writing
+    # (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
+    # facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
+    if flavor.endswith("_reg4"):
+        warn("DINOv2 with registers is not yet supported in Hugging Face transformers, skipping")
+        pytest.skip(allow_module_level=True)
     match flavor:
         case "dinov2_vits14":
             name = "dinov2-small"
@@ -89,3 +129,19 @@ def test_encoder(
     our_features = our_backbone(x)
 
     assert (our_features - ref_features).abs().max() < 1e-3
+
+
+# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
+def test_encoder_only(
+    our_backbone: ViT,
+    seed_expected_norm: tuple[int, float],
+    test_device: torch.device,
+):
+    seed, expected_norm = seed_expected_norm
+    manual_seed(seed)
+
+    x = torch.randn(1, 3, 518, 518).to(test_device)
+
+    our_features = our_backbone(x)
+
+    assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)
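
Note for reviewers: the reference norms hard-coded in `seed_expected_norm` are exactly the values that `test_encoder_only` reproduces, so they can be regenerated with a minimal sketch along these lines (assumptions: the converted `*_pretrain.safetensors` weights are available locally, and the weights path shown is a placeholder to adapt per flavor):

    import torch
    from refiners.fluxion.utils import load_from_safetensors, manual_seed
    from refiners.foundationals.dinov2 import DINOv2_small_reg

    # Build the backbone for one flavor and load its converted weights,
    # mirroring the `our_backbone` fixture (placeholder path).
    backbone = DINOv2_small_reg()
    backbone.load_state_dict(load_from_safetensors("dinov2_vits14_reg4_pretrain.safetensors"))

    # Same seed and input shape as `test_encoder_only`.
    manual_seed(42)
    x = torch.randn(1, 3, 518, 518)

    # Should land close to the fixture value for this flavor (989.2380981);
    # small cross-device numerical differences are absorbed by rel_tol=1e-04.
    print(backbone(x).norm().item())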