mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add minimal unit tests for DINOv2
To be completed with tests using image preprocessing, e.g. test cosine similarity on a relevant pair of images
This commit is contained in:
parent
832f012fe4
commit
68cc346905
|
@ -2,6 +2,9 @@ import torch
|
||||||
|
|
||||||
from refiners.foundationals.dinov2.vit import ViT
|
from refiners.foundationals.dinov2.vit import ViT
|
||||||
|
|
||||||
|
# TODO: add preprocessing logic like
|
||||||
|
# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77
|
||||||
|
|
||||||
|
|
||||||
class DINOv2_small(ViT):
|
class DINOv2_small(ViT):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -269,6 +269,7 @@ class ViT(fl.Chain):
|
||||||
),
|
),
|
||||||
dim=1,
|
dim=1,
|
||||||
),
|
),
|
||||||
|
# TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
|
||||||
PositionalEncoder(
|
PositionalEncoder(
|
||||||
sequence_length=num_patches**2 + 1,
|
sequence_length=num_patches**2 + 1,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
|
|
91
tests/foundationals/dinov2/test_dinov2.py
Normal file
91
tests/foundationals/dinov2/test_dinov2.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModel # type: ignore
|
||||||
|
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, manual_seed
|
||||||
|
from refiners.foundationals.dinov2 import DINOv2_base, DINOv2_large, DINOv2_small
|
||||||
|
from refiners.foundationals.dinov2.vit import ViT
|
||||||
|
|
||||||
|
# TODO: add DINOv2 with registers ("dinov2_vits14_reg", etc). At the time of writing, those are not yet supported in
|
||||||
|
# transformers (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
|
||||||
|
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
|
||||||
|
FLAVORS = [
|
||||||
|
"dinov2_vits14",
|
||||||
|
"dinov2_vitb14",
|
||||||
|
"dinov2_vitl14",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=FLAVORS)
|
||||||
|
def flavor(request: pytest.FixtureRequest) -> str:
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
|
||||||
|
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
|
||||||
|
if not weights.is_file():
|
||||||
|
warn(f"could not find weights at {weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
match flavor:
|
||||||
|
case "dinov2_vits14":
|
||||||
|
backbone = DINOv2_small(device=test_device)
|
||||||
|
case "dinov2_vitb14":
|
||||||
|
backbone = DINOv2_base(device=test_device)
|
||||||
|
case "dinov2_vitl14":
|
||||||
|
backbone = DINOv2_large(device=test_device)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
||||||
|
tensors = load_from_safetensors(weights)
|
||||||
|
backbone.load_state_dict(tensors)
|
||||||
|
return backbone
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def dinov2_weights_path(test_weights_path: Path, flavor: str):
|
||||||
|
match flavor:
|
||||||
|
case "dinov2_vits14":
|
||||||
|
name = "dinov2-small"
|
||||||
|
case "dinov2_vitb14":
|
||||||
|
name = "dinov2-base"
|
||||||
|
case "dinov2_vitl14":
|
||||||
|
name = "dinov2-large"
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
||||||
|
r = test_weights_path / "facebook" / name
|
||||||
|
if not r.is_dir():
|
||||||
|
warn(f"could not find DINOv2 weights at {r}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
|
||||||
|
backbone = AutoModel.from_pretrained(dinov2_weights_path) # type: ignore
|
||||||
|
assert isinstance(backbone, Dinov2Model)
|
||||||
|
return backbone.to(test_device) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder(
|
||||||
|
ref_backbone: Dinov2Model,
|
||||||
|
our_backbone: ViT,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
manual_seed(42)
|
||||||
|
|
||||||
|
# Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
|
||||||
|
# e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
|
||||||
|
# [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
|
||||||
|
assert our_backbone.image_size == 518
|
||||||
|
|
||||||
|
x = torch.randn(1, 3, 518, 518).to(test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_features = ref_backbone(x).last_hidden_state
|
||||||
|
our_features = our_backbone(x)
|
||||||
|
|
||||||
|
assert (our_features - ref_features).abs().max() < 1e-3
|
Loading…
Reference in a new issue