From 68cc34690588d8f7c2f946e58e3494eeacf3f3cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Deltheil?=
Date: Sat, 16 Dec 2023 16:16:54 +0100
Subject: [PATCH] add minimal unit tests for DINOv2

To be completed with tests using image preprocessing, e.g. test cosine
similarity on a relevant pair of images
---
 src/refiners/foundationals/dinov2/dinov2.py |  3 +
 src/refiners/foundationals/dinov2/vit.py    |  1 +
 tests/foundationals/dinov2/test_dinov2.py   | 91 +++++++++++++++++++++
 3 files changed, 95 insertions(+)
 create mode 100644 tests/foundationals/dinov2/test_dinov2.py

diff --git a/src/refiners/foundationals/dinov2/dinov2.py b/src/refiners/foundationals/dinov2/dinov2.py
index bd7a3a7..a4fbdf7 100644
--- a/src/refiners/foundationals/dinov2/dinov2.py
+++ b/src/refiners/foundationals/dinov2/dinov2.py
@@ -2,6 +2,9 @@ import torch
 
 from refiners.foundationals.dinov2.vit import ViT
 
+# TODO: add preprocessing logic like
+# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77
+
 
 class DINOv2_small(ViT):
     def __init__(
diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py
index f80052c..46ddc19 100644
--- a/src/refiners/foundationals/dinov2/vit.py
+++ b/src/refiners/foundationals/dinov2/vit.py
@@ -269,6 +269,7 @@ class ViT(fl.Chain):
                 ),
                 dim=1,
             ),
+            # TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
             PositionalEncoder(
                 sequence_length=num_patches**2 + 1,
                 embedding_dim=embedding_dim,
diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py
new file mode 100644
index 0000000..5020cc8
--- /dev/null
+++ b/tests/foundationals/dinov2/test_dinov2.py
@@ -0,0 +1,91 @@
+from pathlib import Path
+from warnings import warn
+
+import pytest
+import torch
+from transformers import AutoModel  # type: ignore
+from transformers.models.dinov2.modeling_dinov2 import Dinov2Model  # type: ignore
+
+from refiners.fluxion.utils import load_from_safetensors, manual_seed
+from refiners.foundationals.dinov2 import DINOv2_base, DINOv2_large, DINOv2_small
+from refiners.foundationals.dinov2.vit import ViT
+
+# TODO: add DINOv2 with registers ("dinov2_vits14_reg", etc). At the time of writing, those are not yet supported in
+# transformers (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
+# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
+FLAVORS = [
+    "dinov2_vits14",
+    "dinov2_vitb14",
+    "dinov2_vitl14",
+]
+
+
+@pytest.fixture(scope="module", params=FLAVORS)
+def flavor(request: pytest.FixtureRequest) -> str:
+    return request.param
+
+
+@pytest.fixture(scope="module")
+def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
+    weights = test_weights_path / f"{flavor}_pretrain.safetensors"
+    if not weights.is_file():
+        warn(f"could not find weights at {weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    match flavor:
+        case "dinov2_vits14":
+            backbone = DINOv2_small(device=test_device)
+        case "dinov2_vitb14":
+            backbone = DINOv2_base(device=test_device)
+        case "dinov2_vitl14":
+            backbone = DINOv2_large(device=test_device)
+        case _:
+            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
+    tensors = load_from_safetensors(weights)
+    backbone.load_state_dict(tensors)
+    return backbone
+
+
+@pytest.fixture(scope="module")
+def dinov2_weights_path(test_weights_path: Path, flavor: str):
+    match flavor:
+        case "dinov2_vits14":
+            name = "dinov2-small"
+        case "dinov2_vitb14":
+            name = "dinov2-base"
+        case "dinov2_vitl14":
+            name = "dinov2-large"
+        case _:
+            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
+    r = test_weights_path / "facebook" / name
+    if not r.is_dir():
+        warn(f"could not find DINOv2 weights at {r}, skipping")
+        pytest.skip(allow_module_level=True)
+    return r
+
+
+@pytest.fixture(scope="module")
+def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
+    backbone = AutoModel.from_pretrained(dinov2_weights_path)  # type: ignore
+    assert isinstance(backbone, Dinov2Model)
+    return backbone.to(test_device)  # type: ignore
+
+
+def test_encoder(
+    ref_backbone: Dinov2Model,
+    our_backbone: ViT,
+    test_device: torch.device,
+):
+    manual_seed(42)
+
+    # Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
+    # e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
+    # [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
+    assert our_backbone.image_size == 518
+
+    x = torch.randn(1, 3, 518, 518).to(test_device)
+
+    with torch.no_grad():
+        ref_features = ref_backbone(x).last_hidden_state
+        our_features = our_backbone(x)
+
+    assert (our_features - ref_features).abs().max() < 1e-3
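
Note: as the commit message says, a follow-up should exercise image preprocessing, e.g. check cosine
similarity on a relevant pair of images. Below is a minimal sketch of what such a test could look like,
not part of this patch. The preprocessing (resize to 518x518 plus ImageNet normalization via torchvision)
is an assumption mirroring facebookresearch/dinov2's eval transform, since refiners has no DINOv2
preprocessing helper yet (see the TODO in dinov2.py). The fixture images, the `preprocess` helper, the
test name, and the similarity threshold are all hypothetical; only the `our_backbone` and `test_device`
fixtures come from the test file above.

# Hypothetical follow-up test (illustrative sketch, not part of this patch).
import torch
from PIL import Image
from torchvision import transforms  # assumed available in the test environment

from refiners.foundationals.dinov2.vit import ViT

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def preprocess(image: Image.Image, device: torch.device) -> torch.Tensor:
    # No positional encoding interpolation yet, so resize to the default 518x518 resolution.
    to_tensor = transforms.Compose(
        [
            transforms.Resize((518, 518)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
    return to_tensor(image).unsqueeze(0).to(device)


def test_cosine_similarity(our_backbone: ViT, test_device: torch.device) -> None:
    # `cat.png` and `cat_resized.png` are placeholder fixtures (any related image pair works).
    image_a = preprocess(Image.open("tests/fixtures/cat.png"), test_device)
    image_b = preprocess(Image.open("tests/fixtures/cat_resized.png"), test_device)

    with torch.no_grad():
        # Index 0 of the output sequence is the class token embedding.
        embedding_a = our_backbone(image_a)[:, 0]
        embedding_b = our_backbone(image_b)[:, 0]

    score = torch.nn.functional.cosine_similarity(embedding_a, embedding_b).item()
    assert score > 0.8  # loose threshold; the exact value depends on the chosen pair

Comparing class tokens keeps the check independent of patch layout; comparing the full feature
sequence would also work but is more sensitive to spatial misalignment between the two images.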